From 43367e4b68e8cd9e10f374c867c408409d4060cd Mon Sep 17 00:00:00 2001 From: WeiXin Date: Fri, 2 Apr 2021 17:42:30 +0800 Subject: [PATCH] support save/load single tensor (#31756) * support save/load single tensor * compatibility modification according to unnittest * Some python2.7 don't have 'copyreg' modules * Handle a syntax error. * Dealing with compatibility problems on Mac. * Dealing with compatibility problems on Mac. * edit unittest to improve coverage. * Modify the code according to the review comments * Reduce redundant code. * support for static graph loading dygraph state_dict * edit code according to CI * edit unittest * edit unnittest * delete redundant file * edit code according to Comments * edit english doc * edit english doc * edit English DOC. * get/set_tensor->get/set_value; return_numpy=False * get/set_tensor->get/set_value; return_numpy=False * edit unnittest * edit unnittest * polish code. --- python/paddle/fluid/dygraph/layers.py | 8 +- python/paddle/fluid/framework.py | 352 +++++++++++++++-- python/paddle/fluid/io.py | 82 +++- .../fluid/tests/unittests/CMakeLists.txt | 2 +- .../unittests/test_imperative_save_load_v2.py | 2 +- .../tests/unittests/test_paddle_save_load.py | 254 +++++++++++- .../unittests/test_static_save_load_large.py | 17 +- .../fluid/tests/unittests/test_variable.py | 1 - python/paddle/framework/io.py | 362 +++++++++++++++++- .../static_mode_white_list.cpython-37.pyc | Bin 0 -> 20217 bytes tools/static_mode_white_list.pyc | Bin 21803 -> 0 bytes 11 files changed, 998 insertions(+), 82 deletions(-) create mode 100644 tools/__pycache__/static_mode_white_list.cpython-37.pyc delete mode 100644 tools/static_mode_white_list.pyc diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 3df0c60852..36637abc6d 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -22,6 +22,8 @@ import copy import weakref import warnings from copy import deepcopy +import inspect + import paddle from . import parallel_helper @@ -1294,10 +1296,12 @@ class Layer(core.Layer): if state is None: raise ValueError("{} is not found in the provided dict.".format( key)) - if list(state.shape) != list(param.shape): + state_shape = state.shape() if inspect.ismethod( + state.shape) else state.shape + if list(state_shape) != list(param.shape): raise ValueError( "{} receives a shape {}, but the expected shape is {}.". - format(key, list(state.shape), list(param.shape))) + format(key, list(state_shape), list(param.shape))) return param, state matched_param_state = [] diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index be795b9e59..d5c01d20a9 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -24,6 +24,7 @@ import re import traceback import six import copy +from types import MethodType, FunctionType import numpy as np import subprocess @@ -1183,37 +1184,6 @@ class Variable(object): """ pass - @fake_interface_only - def set_value(self, value): - """ - **Notes**: - **This API is ONLY available in Dygraph mode** - - Set a new value for this Variable. - - Args: - value (Variable|np.ndarray): the new value. - - Examples: - .. 
code-block:: python - - import paddle.fluid as fluid - from paddle.fluid.dygraph.base import to_variable - from paddle.fluid.dygraph import Linear - import numpy as np - - data = np.ones([3, 1024], dtype='float32') - with fluid.dygraph.guard(): - linear = fluid.dygraph.Linear(1024, 4) - t = to_variable(data) - linear(t) # call with default weight - custom_weight = np.random.randn(1024, 4).astype("float32") - linear.weight.set_value(custom_weight) # change existing weight - out = linear(t) # call with different weight - - """ - pass - @fake_interface_only def backward(self, retain_graph=False): """ @@ -2011,6 +1981,159 @@ class Variable(object): return self + def get_value(self, scope=None): + """ + Get the value of variable in given scope. + + Args: + scope(Scope, optional) : If `scope` is None, it will be set to global scope + obtained through 'paddle.static.global_scope()'. Otherwise, use `scope`. + Default: None + + Returns: + Tensor: the value in given scope. + + Examples: + .. code-block:: python + + import paddle + import paddle.static as static + import numpy as np + + paddle.enable_static() + + x = static.data(name="x", shape=[10, 10], dtype='float32') + + y = static.nn.fc(x, 10, name='fc') + place = paddle.CPUPlace() + exe = static.Executor(place) + prog = paddle.static.default_main_program() + exe.run(static.default_startup_program()) + inputs = np.ones((10, 10), dtype='float32') + exe.run(prog, feed={'x': inputs}, fetch_list=[y, ]) + path = 'temp/tensor_' + for var in prog.list_vars(): + if var.persistable: + t = var.get_value() + paddle.save(t, path+var.name+'.pdtensor') + + for var in prog.list_vars(): + if var.persistable: + t_load = paddle.load(path+var.name+'.pdtensor') + var.set_value(t_load) + """ + # The 'framework' is a low-level module, and 'executor' + # can not be imported at the begainning of this file. + # Therefore, the above two modules are dynamically imported. + from .executor import global_scope + if scope is not None and not isinstance(scope, core._Scope): + raise TypeError( + "`scope` should be None or `paddle.static.Scope` type, but received {}.". + format(type(scope))) + + if scope is None: + scope = global_scope() + var_temp = scope.find_var(self.name) + if var_temp is None: + raise ValueError("Can not find Variable '{}' in the Scope.".format( + self.name)) + t = var_temp.get_tensor() + return t + + def set_value(self, value, scope=None): + ''' + Set the value to the tensor in given scope. + + Args: + value(Tensor/ndarray) : The value to be set. + scope(Scope, optional) : If `scope` is None, it will be set to global scope + obtained through 'paddle.static.global_scope()'. Otherwise, use `scope`. + Default: None + + Returns: + None + + Examples: + .. 
code-block:: python + + import paddle + import paddle.static as static + import numpy as np + + paddle.enable_static() + + x = static.data(name="x", shape=[10, 10], dtype='float32') + + y = static.nn.fc(x, 10, name='fc') + place = paddle.CPUPlace() + exe = static.Executor(place) + prog = paddle.static.default_main_program() + exe.run(static.default_startup_program()) + inputs = np.ones((10, 10), dtype='float32') + exe.run(prog, feed={'x': inputs}, fetch_list=[y, ]) + path = 'temp/tensor_' + for var in prog.list_vars(): + if var.persistable: + t = var.get_value() + paddle.save(t, path+var.name+'.pdtensor') + + for var in prog.list_vars(): + if var.persistable: + t_load = paddle.load(path+var.name+'.pdtensor') + var.set_value(t_load) + ''' + + # The 'framework' is a low-level module, and 'executor' + # can not be imported at the begainning of this file. + # Therefore, the above two modules are dynamically imported. + from .executor import global_scope + + if not (isinstance(value, np.ndarray) or hasattr(value, '__array__')): + raise TypeError( + "`value` should be `numpy.ndarray` or `LoDTensor`, but received {}.". + format(type(value))) + + if scope is not None and not isinstance(scope, core._Scope): + raise TypeError( + "`scope` should be None or `paddle.static.Scope` type, but received {}.". + format(type(scope))) + + if scope is None: + scope = global_scope() + + var_temp = scope.find_var(self.name) + if var_temp is None: + raise ValueError("Can not find Variable '{}' in the Scope.".format( + self.name)) + + t = var_temp.get_tensor() + + if hasattr(value, 'shape'): + if isinstance(value.shape, (MethodType, FunctionType)): + value_shape = value.shape() + else: + value_shape = value.shape + if list(t.shape()) != list(value_shape): + raise ValueError( + "{} expected a shape {}, but the received shape is {}.". + format(self.name, list(t.shape()), list(value_shape))) + + p = t._place() + if p.is_cpu_place(): + place = core.CPUPlace() + elif p.is_cuda_pinned_place(): + place = core.CUDAPinnedPlace() + elif p.is_xpu_place(): + p = core.Place() + p.set_place(t._place()) + place = core.XPUPlace(p.xpu_device_id()) + else: + p = core.Place() + p.set_place(t._place()) + place = core.CUDAPlace(p.gpu_device_id()) + + t.set(value, place) + def get_all_op_protos(): """ @@ -5319,6 +5442,173 @@ class Program(object): parameters.extend(each_block.all_parameters()) return parameters + def state_dict(self, mode='all', scope=None): + """ + Get parameters and persistable buffers of program as a dict. The key is the name of the parameter or the name of the buffer. + The value is the tensor of this variable in the given scope. + + .. note:: + This function MUST called after run start_up_program + + Args: + mode(str, optional): Source of the obtained parameters and buffers. + 'opt' : The return value only contains the variable in the optimizer. + 'param' : The return value only contains the variable in the network, not the variable in the optimizer. + 'all' : The return value contains the variable in the network and optimizer. + Default: 'all' + scope(Scope, optional) : If scope is None, state_dict will be set to global scope + obtained through 'paddle.static.global_scope()'. Otherwise, value will be set to scope. + Default: None + + Retruns: + dict: a dict contains the parameters and persistable buffers. + + Examples: + .. 
code-block:: python + + import paddle + import paddle.static as static + + paddle.enable_static() + + x = static.data(name="x", shape=[10, 10], dtype='float32') + y = static.nn.fc(x, 10) + z = static.nn.fc(y, 10) + + place = paddle.CPUPlace() + exe = static.Executor(place) + exe.run(static.default_startup_program()) + prog = static.default_main_program() + + path = "./temp/model.pdparams" + paddle.save(prog.state_dict(), path) + """ + # The 'framework' is a low-level module, and 'executor' + # can not be imported at the begainning of this file. + # Therefore, the above two modules are dynamically imported. + from .executor import global_scope + if scope is not None and not isinstance(scope, core._Scope): + raise TypeError( + "`scope` should be None or `paddle.static.Scope'` type, but received {}.". + format(type(scope))) + + if scope is None: + scope = global_scope() + + if not isinstance(mode, str): + raise TypeError("Type of `mode` should be string, but received {}.". + format(type(mode))) + + def is_parameter(var): + return isinstance(var, Parameter) + + def is_persistable(var): + if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ + var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ + var.desc.type() == core.VarDesc.VarType.READER: + return False + return var.persistable + + def is_belong_to_optimizer(var): + if not (isinstance(var, Parameter) or var.desc.need_check_feed()): + return is_persistable(var) + return False + + def condition(var): + + if mode == 'param': + return is_parameter(var) + elif mode == 'opt': + return is_belong_to_optimizer(var) + elif mode == 'all': + return is_parameter(var) or is_belong_to_optimizer(var) + else: + raise ValueError( + "`mode` string should be 'param', 'opt' or 'all', but received {}.". + format(mode)) + + var_list = filter(condition, self.list_vars()) + + state_dict = dict() + for var in var_list: + var_temp = scope.find_var(var.name) + if var_temp is None: + raise ValueError( + "Can not find Variable '{}' in the scope. Make sure it is initialized". + format(var.name)) + state_dict[var.name] = var_temp.get_tensor() + + return state_dict + + def set_state_dict(self, state_dict, scope=None): + """ + Set parameters and persistable buffers in state_dict to program. + An exception will throw if shape or dtype of the parameters is not match. + + .. note:: + This function MUST called after run start_up_program + + Args: + state_dict(dict): the dict store parameters and persistable buffers. + The key is the name of the parameter or the name of the buffer. + The value is the tensor of this variable in the given scope. + scope(Scope, optional) : If scope is None, state_dict will be set to global scope + obtained through 'paddle.static.global_scope()'. Otherwise, value will be set to scope. + Default: None + + Returns: + None + + Examples: + .. 
code-block:: python + + import paddle + import paddle.static as static + + paddle.enable_static() + + x = static.data(name="x", shape=[10, 10], dtype='float32') + y = static.nn.fc(x, 10) + z = static.nn.fc(y, 10) + + place = paddle.CPUPlace() + exe = static.Executor(place) + exe.run(static.default_startup_program()) + prog = static.default_main_program() + + path = "./temp/model.pdparams" + paddle.save(prog.state_dict(), path) + state_dict_load = paddle.load(path) + prog.set_state_dict(state_dict_load) + """ + + if not isinstance(state_dict, dict): + raise TypeError( + "Type of `state_dict` should be dict, but received {}.".format( + type(state_dict))) + + vars_dict = {var.name: var for var in self.list_vars()} + condition = True if 'StructuredToParameterName@@' in state_dict else False + for name, value in state_dict.items(): + if condition: + if name == "StructuredToParameterName@@": + continue + if name in state_dict['StructuredToParameterName@@']: + name = state_dict['StructuredToParameterName@@'][name] + if name in vars_dict: + try: + vars_dict[name].set_value(value, scope) + except ValueError as err: + warnings.warn( + ("Skip loading for '{}'. ".format(name) + str(err))) + except TypeError as err: + warnings.warn( + ("Skip loading for '{}'. ".format(name) + str(err))) + else: + warnings.warn(( + "Skip loading for '{0}'. Because '{0}' not in the program.". + format(name))) + @six.add_metaclass(ParameterMetaClass) class Parameter(Variable): diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 9cca3e16de..cfb4b12599 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -1765,7 +1765,30 @@ def _pack_loaded_dict(load_obj): @static_only -def save(program, model_path, pickle_protocol=2): +def _legacy_save(param_dict, model_path, protocol=2): + def get_tensor(var): + if isinstance(var, core.VarBase): + return var.numpy() + elif isinstance(var, core.LoDTensor): + return np.array(var) + return var + + param_dict = {name: get_tensor(param_dict[name]) for name in param_dict} + + # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3' + if sys.platform == 'darwin' and sys.version_info.major == 3: + pickle_bytes = pickle.dumps(param_dict, protocol=protocol) + with open(model_path, 'wb') as f: + max_bytes = 2**30 + for i in range(0, len(pickle_bytes), max_bytes): + f.write(pickle_bytes[i:i + max_bytes]) + else: + with open(model_path, 'wb') as f: + pickle.dump(param_dict, f, protocol=protocol) + + +@static_only +def save(program, model_path, protocol=2, **configs): """ :api_attr: Static Graph @@ -1778,8 +1801,9 @@ def save(program, model_path, pickle_protocol=2): Args: program(Program) : The program to saved. model_path(str): the file prefix to save the program. The format is "dirname/file_prefix". If file_prefix is empty str. A exception will be raised - pickle_protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5. + protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5. Default: 2 + configs(dict, optional) : optional keyword arguments. Returns: None @@ -1807,14 +1831,19 @@ def save(program, model_path, pickle_protocol=2): base_name = os.path.basename(model_path) assert base_name != "", \ "The input model_path MUST be format of dirname/filename [dirname\\filename in Windows system], but received model_path is empty string." 
+ if 'pickle_protocol' in configs: + protocol = configs['pickle_protocol'] + warnings.warn( + "'pickle_protocol' is a deprecated argument. Please use 'protocol' instead." + ) - if not isinstance(pickle_protocol, int): + if not isinstance(protocol, int): raise ValueError("The 'protocol' MUST be `int`, but received {}".format( - type(pickle_protocol))) + type(protocol))) - if pickle_protocol < 2 or pickle_protocol > 4: + if protocol < 2 or protocol > 4: raise ValueError("Expected 1<'protocol'<5, but received protocol={}". - format(pickle_protocol)) + format(protocol)) dir_name = os.path.dirname(model_path) if dir_name and not os.path.exists(dir_name): @@ -1827,26 +1856,25 @@ def save(program, model_path, pickle_protocol=2): parameter_list = list(filter(is_parameter, program.list_vars())) param_dict = {p.name: get_tensor(p) for p in parameter_list} - param_dict = _unpack_saved_dict(param_dict, pickle_protocol) + param_dict = _unpack_saved_dict(param_dict, protocol) - # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6' - if sys.platform == 'darwin' and sys.version_info.major == 3 and ( - sys.version_info.minor == 5 or sys.version_info.minor == 6): - pickle_bytes = pickle.dumps(param_dict, protocol=pickle_protocol) + # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3' + if sys.platform == 'darwin' and sys.version_info.major == 3: + pickle_bytes = pickle.dumps(param_dict, protocol=protocol) with open(model_path + ".pdparams", 'wb') as f: max_bytes = 2**30 for i in range(0, len(pickle_bytes), max_bytes): f.write(pickle_bytes[i:i + max_bytes]) else: with open(model_path + ".pdparams", 'wb') as f: - pickle.dump(param_dict, f, protocol=pickle_protocol) + pickle.dump(param_dict, f, protocol=protocol) optimizer_var_list = list( filter(is_belong_to_optimizer, program.list_vars())) opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list} with open(model_path + ".pdopt", 'wb') as f: - pickle.dump(opt_dict, f, protocol=pickle_protocol) + pickle.dump(opt_dict, f, protocol=protocol) main_program = program.clone() program.desc.flush() @@ -1857,6 +1885,17 @@ def save(program, model_path, pickle_protocol=2): f.write(program.desc.serialize_to_string()) +def _pickle_loads_mac(path, f): + pickle_bytes = bytearray(0) + file_size = os.path.getsize(path) + max_bytes = 2**30 + for _ in range(0, file_size, max_bytes): + pickle_bytes += f.read(max_bytes) + load_result = pickle.loads(pickle_bytes) if six.PY2 else pickle.loads( + pickle_bytes, encoding='latin1') + return load_result + + @static_only def load(program, model_path, executor=None, var_list=None): """ @@ -2016,8 +2055,13 @@ def load(program, model_path, executor=None, var_list=None): global_scope(), executor._default_executor) with open(parameter_file_name, 'rb') as f: - load_dict = pickle.load(f) if six.PY2 else pickle.load( - f, encoding='latin1') + + # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3' + if sys.platform == 'darwin' and sys.version_info.major == 3: + load_dict = _pickle_loads_mac(parameter_file_name, f) + else: + load_dict = pickle.load(f) if six.PY2 else pickle.load( + f, encoding='latin1') load_dict = _pack_loaded_dict(load_dict) for v in parameter_list: assert v.name in load_dict, \ @@ -2196,8 +2240,12 @@ def load_program_state(model_path, var_list=None): "Parameter file [{}] not exits".format(parameter_file_name) with open(parameter_file_name, 'rb') as f: - para_dict = pickle.load(f) if six.PY2 else pickle.load( - f, encoding='latin1') + # When value of dict is lager 
than 4GB ,there is a Bug on 'MAC python3' + if sys.platform == 'darwin' and sys.version_info.major == 3: + para_dict = _pickle_loads_mac(parameter_file_name, f) + else: + para_dict = pickle.load(f) if six.PY2 else pickle.load( + f, encoding='latin1') para_dict = _pack_loaded_dict(para_dict) opt_file_name = model_prefix + ".pdopt" diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 28f5177c20..add3bbee41 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -726,7 +726,7 @@ if (WIN32) set_tests_properties(test_paddle_save_load PROPERTIES TIMEOUT 250) else() set_tests_properties(test_static_save_load_large PROPERTIES TIMEOUT 600) - set_tests_properties(test_paddle_save_load PROPERTIES TIMEOUT 150) + set_tests_properties(test_paddle_save_load PROPERTIES TIMEOUT 250) endif() set_tests_properties(test_imperative_selected_rows_to_lod_tensor PROPERTIES TIMEOUT 120) set_tests_properties(test_index_select_op PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_save_load_v2.py b/python/paddle/fluid/tests/unittests/test_imperative_save_load_v2.py index 672ffa9d39..9f0dcdb4d8 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_save_load_v2.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_save_load_v2.py @@ -930,7 +930,7 @@ class TestDygraphPtbRnn(unittest.TestCase): paddle.save(state_dict, os.path.join('saved_dy', 'emb_dy.pdparams')) para_state_dict = paddle.load( - os.path.join('saved_dy', 'emb_dy.pdparams')) + os.path.join('saved_dy', 'emb_dy.pdparams'), return_numpy=True) para_state_dict['weight'] = np.expand_dims( para_state_dict['weight'], axis=-1) diff --git a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py index 06f63d1416..b58d63969a 100644 --- a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py @@ -18,10 +18,15 @@ import unittest import numpy as np import os import sys +import six import paddle import paddle.nn as nn import paddle.optimizer as opt +import paddle.fluid as fluid +from paddle.fluid.optimizer import Adam +import paddle.fluid.framework as framework +from test_imperative_base import new_program_scope BATCH_SIZE = 16 BATCH_NUM = 4 @@ -31,7 +36,10 @@ SEED = 10 IMAGE_SIZE = 784 CLASS_NUM = 10 -LARGE_PARAM = 2**26 +if six.PY2: + LARGE_PARAM = 2**2 +else: + LARGE_PARAM = 2**26 def random_batch_reader(): @@ -95,15 +103,22 @@ class TestSaveLoadLargeParameters(unittest.TestCase): path = os.path.join("test_paddle_save_load_large_param_save", "layer.pdparams") - paddle.save(layer.state_dict(), path) + if six.PY2: + protocol = 2 + else: + protocol = 4 + paddle.save(save_dict, path, protocol=protocol) dict_load = paddle.load(path) # compare results before and after saving for key, value in save_dict.items(): - self.assertTrue(np.array_equal(dict_load[key], value.numpy())) + self.assertTrue( + np.array_equal(dict_load[key].numpy(), value.numpy())) class TestSaveLoadPickle(unittest.TestCase): def test_pickle_protocol(self): + # enable dygraph mode + paddle.disable_static() # create network layer = LinearNet() save_dict = layer.state_dict() @@ -124,11 +139,236 @@ class TestSaveLoadPickle(unittest.TestCase): if sys.version_info.major >= 3 and sys.version_info.minor >= 4: protocols += [3, 4] for protocol in protocols: - paddle.save(save_dict, path, 
protocol) + paddle.save(save_dict, path, pickle_protocol=protocol) dict_load = paddle.load(path) # compare results before and after saving for key, value in save_dict.items(): - self.assertTrue(np.array_equal(dict_load[key], value.numpy())) + self.assertTrue( + np.array_equal(dict_load[key].numpy(), value.numpy())) + + +class TestSaveLoadAny(unittest.TestCase): + def set_zero(self, prog, place, scope=None): + if scope is None: + scope = fluid.global_scope() + for var in prog.list_vars(): + if isinstance(var, framework.Parameter) or var.persistable: + ten = scope.find_var(var.name).get_tensor() + if ten is not None: + ten.set(np.zeros_like(np.array(ten)), place) + new_t = np.array(scope.find_var(var.name).get_tensor()) + self.assertTrue(np.sum(np.abs(new_t)) == 0) + + def replace_static_save(self, program, model_path, pickle_protocol=2): + with self.assertRaises(TypeError): + program.state_dict(1) + with self.assertRaises(TypeError): + program.state_dict(scope=1) + with self.assertRaises(ValueError): + program.state_dict('x') + state_dict_param = program.state_dict('param') + paddle.save(state_dict_param, model_path + '.pdparams') + state_dict_opt = program.state_dict('opt') + paddle.save(state_dict_opt, model_path + '.pdopt') + state_dict_all = program.state_dict() + paddle.save(state_dict_opt, model_path + '.pdall') + + def replace_static_load(self, program, model_path): + with self.assertRaises(TypeError): + program.set_state_dict(1) + state_dict_param = paddle.load(model_path + '.pdparams') + state_dict_param['fake_var_name.@@'] = np.random.randn(1, 2) + state_dict_param['static_x'] = 'UserWarning' + program.set_state_dict(state_dict_param) + state_dict_param['static_x'] = np.random.randn(1, 2) + program.set_state_dict(state_dict_param) + program.set_state_dict(state_dict_param) + state_dict_opt = paddle.load(model_path + '.pdopt') + program.set_state_dict(state_dict_opt) + + def test_replace_static_save_load(self): + paddle.enable_static() + with new_program_scope(): + x = paddle.static.data( + name="static_x", shape=[None, IMAGE_SIZE], dtype='float32') + z = paddle.static.nn.fc(x, 10) + z = paddle.static.nn.fc(z, 10, bias_attr=False) + loss = fluid.layers.reduce_mean(z) + opt = Adam(learning_rate=1e-3) + opt.minimize(loss) + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + prog = paddle.static.default_main_program() + fake_inputs = np.random.randn(2, IMAGE_SIZE).astype('float32') + exe.run(prog, feed={'static_x': fake_inputs}, fetch_list=[loss]) + base_map = {} + for var in prog.list_vars(): + if isinstance(var, framework.Parameter) or var.persistable: + t = np.array(fluid.global_scope().find_var(var.name) + .get_tensor()) + base_map[var.name] = t + path = os.path.join("test_replace_static_save_load", "model") + # paddle.save, legacy paddle.fluid.load + self.replace_static_save(prog, path) + self.set_zero(prog, place) + paddle.fluid.io.load(prog, path) + for var in prog.list_vars(): + if isinstance(var, framework.Parameter) or var.persistable: + new_t = np.array(fluid.global_scope().find_var(var.name) + .get_tensor()) + base_t = base_map[var.name] + self.assertTrue(np.array_equal(new_t, np.array(base_t))) + # legacy paddle.fluid.save, paddle.load + paddle.fluid.io.save(prog, path) + self.set_zero(prog, place) + self.replace_static_load(prog, path) + for var in prog.list_vars(): + if isinstance(var, framework.Parameter) or var.persistable: + new_t = np.array(fluid.global_scope().find_var(var.name) + .get_tensor()) + 
base_t = base_map[var.name] + self.assertTrue(np.array_equal(new_t, base_t)) + # test for return tensor + path_vars = 'test_replace_save_load_return_tensor_static/model' + for var in prog.list_vars(): + if var.persistable: + tensor = var.get_value(fluid.global_scope()) + paddle.save(tensor, os.path.join(path_vars, var.name)) + with self.assertRaises(TypeError): + var.get_value('fluid.global_scope()') + with self.assertRaises(ValueError): + x.get_value() + with self.assertRaises(TypeError): + x.set_value('1') + fake_data = np.zeros([3, 2, 1, 2, 3]) + with self.assertRaises(TypeError): + x.set_value(fake_data, '1') + with self.assertRaises(ValueError): + x.set_value(fake_data) + with self.assertRaises(ValueError): + var.set_value(fake_data) + # set var to zero + self.set_zero(prog, place) + for var in prog.list_vars(): + if var.persistable: + tensor = paddle.load( + os.path.join(path_vars, var.name), return_numpy=False) + var.set_value(tensor) + new_t = np.array(fluid.global_scope().find_var(var.name) + .get_tensor()) + base_t = base_map[var.name] + self.assertTrue(np.array_equal(new_t, base_t)) + + def test_paddle_save_load_v2(self): + paddle.disable_static() + layer = LinearNet() + state_dict = layer.state_dict() + path = 'paddle_save_load_v2/model.pdparams' + with self.assertRaises(TypeError): + paddle.save(state_dict, path, use_binary_format='False') + # legacy paddle.save, paddle.load + paddle.framework.io._legacy_save(state_dict, path) + load_dict_tensor = paddle.load(path, return_numpy=False) + # legacy paddle.load, paddle.save + paddle.save(state_dict, path) + load_dict_np = paddle.framework.io._legacy_load(path) + for k, v in state_dict.items(): + self.assertTrue( + np.array_equal(v.numpy(), load_dict_tensor[k].numpy())) + self.assertTrue(np.array_equal(v.numpy(), load_dict_np[k])) + + def test_single_pickle_var_dygraph(self): + # enable dygraph mode + paddle.disable_static() + layer = LinearNet() + path = 'paddle_save_load_v2/var_dygraph' + tensor = layer._linear.weight + with self.assertRaises(ValueError): + paddle.save(tensor, path, pickle_protocol='3') + with self.assertRaises(ValueError): + paddle.save(tensor, path, pickle_protocol=5) + paddle.save(tensor, path) + t_dygraph = paddle.load(path) + np_dygraph = paddle.load(path, return_numpy=True) + self.assertTrue(isinstance(t_dygraph, paddle.fluid.core.VarBase)) + self.assertTrue(np.array_equal(tensor.numpy(), np_dygraph)) + self.assertTrue(np.array_equal(tensor.numpy(), t_dygraph.numpy())) + paddle.enable_static() + lod_static = paddle.load(path) + np_static = paddle.load(path, return_numpy=True) + self.assertTrue(isinstance(lod_static, paddle.fluid.core.LoDTensor)) + self.assertTrue(np.array_equal(tensor.numpy(), np_static)) + self.assertTrue(np.array_equal(tensor.numpy(), np.array(lod_static))) + + def test_single_pickle_var_static(self): + # enable static mode + paddle.enable_static() + with new_program_scope(): + # create network + x = paddle.static.data( + name="x", shape=[None, IMAGE_SIZE], dtype='float32') + z = paddle.static.nn.fc(x, 128) + loss = fluid.layers.reduce_mean(z) + place = fluid.CPUPlace( + ) if not paddle.fluid.core.is_compiled_with_cuda( + ) else fluid.CUDAPlace(0) + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + prog = paddle.static.default_main_program() + for var in prog.list_vars(): + if list(var.shape) == [IMAGE_SIZE, 128]: + tensor = var.get_value() + break + scope = fluid.global_scope() + origin_tensor = np.array(tensor) + path = 
'test_single_pickle_var_static/var' + paddle.save(tensor, path) + self.set_zero(prog, place, scope) + # static load + lod_static = paddle.load(path) + np_static = paddle.load(path, return_numpy=True) + # set_tensor(np.ndarray) + var.set_value(np_static, scope) + self.assertTrue(np.array_equal(origin_tensor, np.array(tensor))) + # set_tensor(LoDTensor) + self.set_zero(prog, place, scope) + var.set_value(lod_static, scope) + self.assertTrue(np.array_equal(origin_tensor, np.array(tensor))) + # enable dygraph mode + paddle.disable_static() + var_dygraph = paddle.load(path) + np_dygraph = paddle.load(path, return_numpy=True) + self.assertTrue(np.array_equal(np.array(tensor), np_dygraph)) + self.assertTrue(np.array_equal(np.array(tensor), var_dygraph.numpy())) + + def test_dygraph_save_static_load(self): + inps = np.random.randn(1, IMAGE_SIZE).astype('float32') + path = 'test_dygraph_save_static_load/dy-static.pdparams' + paddle.disable_static() + with paddle.utils.unique_name.guard(): + layer = LinearNet() + state_dict_dy = layer.state_dict() + paddle.save(state_dict_dy, path) + paddle.enable_static() + with new_program_scope(): + layer = LinearNet() + data = paddle.static.data( + name='x_static_save', shape=(None, IMAGE_SIZE), dtype='float32') + y_static = layer(data) + program = paddle.static.default_main_program() + place = fluid.CPUPlace( + ) if not paddle.fluid.core.is_compiled_with_cuda( + ) else fluid.CUDAPlace(0) + exe = paddle.static.Executor(paddle.CPUPlace()) + exe.run(paddle.static.default_startup_program()) + state_dict = paddle.load(path, keep_name_table=True) + program.set_state_dict(state_dict) + state_dict_param = program.state_dict("param") + for name, tensor in state_dict_dy.items(): + self.assertTrue( + np.array_equal(tensor.numpy(), + np.array(state_dict_param[tensor.name]))) class TestSaveLoad(unittest.TestCase): @@ -158,7 +398,9 @@ class TestSaveLoad(unittest.TestCase): def check_load_state_dict(self, orig_dict, load_dict): for var_name, value in orig_dict.items(): - self.assertTrue(np.array_equal(value.numpy(), load_dict[var_name])) + load_value = load_dict[var_name].numpy() if hasattr( + load_dict[var_name], 'numpy') else np.array(load_dict[var_name]) + self.assertTrue(np.array_equal(value.numpy(), load_value)) def test_save_load(self): layer, opt = self.build_and_train_model() diff --git a/python/paddle/fluid/tests/unittests/test_static_save_load_large.py b/python/paddle/fluid/tests/unittests/test_static_save_load_large.py index 08413d711b..c5dc98af5c 100644 --- a/python/paddle/fluid/tests/unittests/test_static_save_load_large.py +++ b/python/paddle/fluid/tests/unittests/test_static_save_load_large.py @@ -25,12 +25,17 @@ import six import pickle import os +# Python2.x no longer supports saving and loading large parameters. 
+if six.PY2: + LARGE_PARAM = 2 +else: + LARGE_PARAM = 2**26 + class TestStaticSaveLoadLargeParameters(unittest.TestCase): def test_large_parameters_static_save(self): # enable static mode paddle.enable_static() - LARGE_PARAM = 2**26 with new_program_scope(): # create network x = paddle.static.data( @@ -54,7 +59,11 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase): path = os.path.join("test_static_save_load_large_param", "static_save") - paddle.fluid.save(prog, path) + if six.PY2: + protocol = 2 + else: + protocol = 4 + paddle.fluid.save(prog, path, pickle_protocol=protocol) # set var to zero for var in prog.list_vars(): if isinstance(var, framework.Parameter) or var.persistable: @@ -92,3 +101,7 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase): .get_tensor()) base_t = base_map[var.name] self.assertTrue(np.array_equal(new_t, base_t)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index 8d5ab0a5be..690ac46e56 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -190,7 +190,6 @@ class TestVariable(unittest.TestCase): with fluid.dygraph.guard(): self.assertRaises(AssertionError, var.detach) self.assertRaises(AssertionError, var.numpy) - self.assertRaises(AssertionError, var.set_value, None) self.assertRaises(AssertionError, var.backward) self.assertRaises(AssertionError, var.gradient) self.assertRaises(AssertionError, var.clear_gradient) diff --git a/python/paddle/framework/io.py b/python/paddle/framework/io.py index 3d93bed32e..3b953efab7 100644 --- a/python/paddle/framework/io.py +++ b/python/paddle/framework/io.py @@ -22,13 +22,18 @@ import warnings import sys import numpy as np +if not six.PY2: + import copyreg + import paddle # deprecated module import from paddle import fluid from paddle.fluid import core -from paddle.fluid.io import _unpack_saved_dict, _pack_loaded_dict -from paddle.fluid.framework import Variable, _varbase_creator, _dygraph_tracer +from paddle.fluid.io import _unpack_saved_dict, _pack_loaded_dict, _pickle_loads_mac +from paddle.fluid.io import _legacy_save as _legacy_static_save + +from paddle.fluid.framework import Variable, _varbase_creator, _dygraph_tracer, in_dygraph_mode, ParamBase, _current_expected_place from paddle.fluid.dygraph.jit import _SaveLoadConfig from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX @@ -181,7 +186,9 @@ def _build_load_path_and_config(path, config): def _parse_load_config(configs): - supported_configs = ['model_filename', 'params_filename', 'keep_name_table'] + supported_configs = [ + 'model_filename', 'params_filename', 'keep_name_table', 'return_numpy' + ] # input check for key in configs: @@ -195,16 +202,158 @@ def _parse_load_config(configs): inner_config.model_filename = configs.get('model_filename', None) inner_config.params_filename = configs.get('params_filename', None) inner_config.keep_name_table = configs.get('keep_name_table', None) + inner_config.return_numpy = configs.get('return_numpy', False) return inner_config -def save(obj, path, pickle_protocol=2): +def _parse_save_config(configs): + supported_configs = ['use_binary_format', 'pickle_protocol'] + + # input check + for key in configs: + if key not in supported_configs: + raise ValueError( + "The additional 
config (%s) of `paddle.save` is not supported." + % key) + + # construct inner config + inner_config = _SaveLoadConfig() + inner_config.use_binary_format = configs.get('use_binary_format', False) + inner_config.pickle_protocol = configs.get('pickle_protocol', None) + + return inner_config + + +def _pickle_save(obj, f, protocol): + # TODO(weixin):add support for BytesIO. + if not isinstance(protocol, int): + raise ValueError("The 'protocol' MUST be `int`, but received {}".format( + type(protocol))) + + if protocol < 2 or protocol > 4: + raise ValueError("Expected 1<'protocol'<5, but received protocol={}". + format(protocol)) + + if not isinstance(obj, (core.LoDTensor, core.VarBase)): + raise NotImplementedError( + "Support 'paddle.Tensor' or 'paddle.core.LoDTensor', but received {}.". + format(type(obj))) + + def reudce_varbase(self): + data = self.numpy() + name = self.name + + return (tuple, ((name, data), )) + + def reduce_LoDTensor(self): + data = np.array(self) + + return (eval, ('data', {'data': data})) + + def add_dispatch_table(): + # This is not a good method, because the pickle module has been modified. + pickle.dispatch_table[core.VarBase] = reudce_varbase + pickle.dispatch_table[ParamBase] = reudce_varbase + pickle.dispatch_table[core.LoDTensor] = reduce_LoDTensor + + def pop_dispatch_table(): + pickle.dispatch_table.pop(core.VarBase) + pickle.dispatch_table.pop(core.LoDTensor) + pickle.dispatch_table.pop(ParamBase) + + # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3' + if sys.platform == 'darwin' and sys.version_info.major == 3: + add_dispatch_table() + pickle_bytes = pickle.dumps(obj) + pop_dispatch_table() + + max_bytes = 2**30 + for i in range(0, len(pickle_bytes), max_bytes): + f.write(pickle_bytes[i:i + max_bytes]) + else: + if six.PY2: + add_dispatch_table() + pickle_bytes = pickle.dump(obj, f, protocol) + pop_dispatch_table() + else: + pickler = pickle.Pickler(f, protocol) + pickler.dispatch_table = copyreg.dispatch_table.copy() + + pickler.dispatch_table[core.VarBase] = reudce_varbase + pickler.dispatch_table[core.LoDTensor] = reduce_LoDTensor + pickler.dispatch_table[ParamBase] = reudce_varbase + + pickler.dump(obj) + + +def _use_legacy(obj): + # TODO(weixin):If `obj` is any object, the judgment condition should be more precise. + if not isinstance(obj, dict): + return False + return True + + +def _transformed_from_varbase(obj): + # In paddle2.1 version, VarBase is saved as tuple(tensor.name, tensor.numpy()). + # When executing paddle.load, use this function to determine whether to restore to VarBase/LoDTensor. + if isinstance(obj, tuple) and len(obj) == 2: + if six.PY2: + name_types = (str, unicode) + else: + name_types = str + if isinstance(obj[0], name_types) and isinstance(obj[1], np.ndarray): + return True + return False + + +def _transformed_from_lodtensor(obj): + # In paddle2.1 version, LoDTensor is saved as np.array(tensor). + # When executing paddle.load, use this function to determine whether to restore to VarBase/LoDTensor. + if isinstance(obj, np.ndarray): + return True + return False + + +def _to_LodTensor(ndarray): + if not isinstance(ndarray, np.ndarray): + raise TypeError( + 'Type of `ndarray` should be numpy.ndarray, but received {}.'. 
+ format(type(ndarray))) + t = core.LoDTensor() + place = _current_expected_place() + t.set(ndarray, place) + return t + + +def _tuple_to_tensor(obj, return_numpy): + if return_numpy: + return obj[1] + if in_dygraph_mode(): + t = paddle.to_tensor(obj[1]) + # This function does modify the name of return value. + # Loading the same variable multiple times may cause the same name. + t.name = obj[0] + return t + else: + return _to_LodTensor(obj[1]) + + +def _ndarray_to_tensor(obj, return_numpy): + if return_numpy: + return obj + if in_dygraph_mode(): + return paddle.to_tensor(obj) + else: + return _to_LodTensor(obj) + + +def save(obj, path, protocol=2, **configs): ''' Save an object to the specified path. .. note:: - Now only supports save ``state_dict`` of Layer or Optimizer. + Now supports saving ``state_dict`` of Layer or Optimizer, Tensor. .. note:: Different from ``paddle.jit.save``, since the save result of ``paddle.save`` is a single file, @@ -219,8 +368,12 @@ def save(obj, path, pickle_protocol=2): obj(Object) : The object to be saved. path(str) : The path of the object to be saved. If saved in the current directory, the input path string will be used as the file name. - pickle_protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5. + protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5. Default: 2 + **configs(dict, optional): optional keyword arguments. The following options are currently supported: + use_binary_format(bool): When the saved object is static graph variable, you can specify ``use_binary_for_var``. + If True, save the file in the c++ binary format when saving a single static graph variable; otherwise, save it in pickle format. + Default: False Returns: None @@ -228,20 +381,91 @@ def save(obj, path, pickle_protocol=2): Examples: .. code-block:: python + # example 1: dynamic graph import paddle - emb = paddle.nn.Embedding(10, 10) layer_state_dict = emb.state_dict() + + # save state_dict of emb paddle.save(layer_state_dict, "emb.pdparams") - scheduler = paddle.optimizer.lr.NoamDecay( + + scheduler = paddle.optimizer.lr.NoamDecay( d_model=0.01, warmup_steps=100, verbose=True) adam = paddle.optimizer.Adam( learning_rate=scheduler, parameters=emb.parameters()) opt_state_dict = adam.state_dict() + + # save state_dict of optimizer paddle.save(opt_state_dict, "adam.pdopt") + # save weight of emb + paddle.save(emb.weight, "emb.weight.pdtensor") + + # example 2: static graph + import paddle + import paddle.static as static + + paddle.enable_static() + + # create network + x = paddle.static.data(name="x", shape=[None, 224], dtype='float32') + z = paddle.static.nn.fc(x, 10) + + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + prog = paddle.static.default_main_program() + for var in prog.list_vars(): + if list(var.shape) == [224, 10]: + tensor = var.get_tensor() + break + + # save/load tensor + path_tensor = 'temp/tensor.pdtensor' + paddle.save(tensor, path_tensor) + + # save/load state_dict + path_state_dict = 'temp/model.pdparams' + paddle.save(prog.state_dict("param"), path_tensor) ''' + # 1. input check + filename = os.path.basename(path) + if filename == "": + raise ValueError("The input path MUST be format of dirname/filename " + "[dirname\\filename in Windows system], but received " + "filename is empty string.") + + # 2. 
save object + dirname = os.path.dirname(path) + if dirname and not os.path.exists(dirname): + os.makedirs(dirname) + + config = _parse_save_config(configs) + + if not isinstance(config.use_binary_format, bool): + raise TypeError( + "Type of `use_binary_format` should be bool, but received {}.". + format(type(config.use_binary_format))) + + # `protocol` need to be used, `pickle_protocol` is a deprecated arg. + if config.pickle_protocol is not None: + protocol = config.pickle_protocol + warnings.warn( + "'pickle_protocol' is a deprecated argument. Please use 'protocol' instead." + ) + + if _use_legacy(obj): + if in_dygraph_mode(): + _legacy_save(obj, path, protocol) + else: + _legacy_static_save(obj, path, protocol) + else: + # save single variable + with open(path, 'wb') as f: + _pickle_save(obj, f, protocol) + +def _legacy_save(obj, path, protocol=2): # 1. input check if not isinstance(obj, dict): raise NotImplementedError( @@ -257,13 +481,13 @@ def save(obj, path, pickle_protocol=2): "[dirname\\filename in Windows system], but received " "filename is empty string.") - if not isinstance(pickle_protocol, int): + if not isinstance(protocol, int): raise ValueError("The 'protocol' MUST be `int`, but received {}".format( - type(pickle_protocol))) + type(protocol))) - if pickle_protocol < 2 or pickle_protocol > 4: + if protocol < 2 or protocol > 4: raise ValueError("Expected 1<'protocol'<5, but received protocol={}". - format(pickle_protocol)) + format(protocol)) # 2. save object dirname = os.path.dirname(path) @@ -274,19 +498,18 @@ def save(obj, path, pickle_protocol=2): if isinstance(obj, dict): saved_obj = _build_saved_state_dict(obj) - saved_obj = _unpack_saved_dict(saved_obj, pickle_protocol) + saved_obj = _unpack_saved_dict(saved_obj, protocol) - # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6' - if sys.platform == 'darwin' and sys.version_info.major == 3 and ( - sys.version_info.minor == 5 or sys.version_info.minor == 6): - pickle_bytes = pickle.dumps(saved_obj, protocol=pickle_protocol) + # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3' + if sys.platform == 'darwin' and sys.version_info.major == 3: + pickle_bytes = pickle.dumps(saved_obj, protocol=protocol) with open(path, 'wb') as f: max_bytes = 2**30 for i in range(0, len(pickle_bytes), max_bytes): f.write(pickle_bytes[i:i + max_bytes]) else: with open(path, 'wb') as f: - pickle.dump(saved_obj, f, protocol=pickle_protocol) + pickle.dump(saved_obj, f, protocol=protocol) def load(path, **configs): @@ -294,7 +517,7 @@ def load(path, **configs): Load an object can be used in paddle from specified path. .. note:: - Now only supports load ``state_dict`` of Layer or Optimizer. + Now supports load ``state_dict`` of Layer or Optimizer, Tensor. .. note:: In order to use the model parameters saved by paddle more efficiently, @@ -331,7 +554,9 @@ def load(path, **configs): ``save_inference_model`` save format. Default file name is :code:`__model__` . (2) params_filename (str): The persistable variables file name of the paddle 1.x ``save_inference_model`` save format. No default file name, save variables separately - by default. + by default. + (3) return_numpy(bool): If specified as True, return tensor as numpy.ndarray, otherwise return tensor as paddle.Tensor. + Default False. 
Returns: Object(Object): a target object can be used in paddle @@ -341,20 +566,115 @@ def load(path, **configs): import paddle + # example 1: dynamic graph + import paddle emb = paddle.nn.Embedding(10, 10) layer_state_dict = emb.state_dict() + + # save state_dict of emb paddle.save(layer_state_dict, "emb.pdparams") - scheduler = paddle.optimizer.lr.NoamDecay( + + scheduler = paddle.optimizer.lr.NoamDecay( d_model=0.01, warmup_steps=100, verbose=True) adam = paddle.optimizer.Adam( learning_rate=scheduler, parameters=emb.parameters()) opt_state_dict = adam.state_dict() + + # save state_dict of optimizer paddle.save(opt_state_dict, "adam.pdopt") + # save weight of emb + paddle.save(emb.weight, "emb.weight.pdtensor") + # load state_dict of emb load_layer_state_dict = paddle.load("emb.pdparams") + # load state_dict of optimizer load_opt_state_dict = paddle.load("adam.pdopt") + # load weight of emb + load_weight = paddle.load("emb.weight.pdtensor") + + + # example 2: static graph + import paddle + import paddle.static as static + + paddle.enable_static() + + # create network + x = paddle.static.data(name="x", shape=[None, 224], dtype='float32') + z = paddle.static.nn.fc(x, 10) + + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + prog = paddle.static.default_main_program() + for var in prog.list_vars(): + if list(var.shape) == [224, 10]: + tensor = var.get_tensor() + break + + # save/load tensor + path_tensor = 'temp/tensor.pdtensor' + paddle.save(tensor, path_tensor) + load_tensor = paddle.load(path_tensor) + + # save/load state_dict + path_state_dict = 'temp/model.pdparams' + paddle.save(prog.state_dict("param"), path_tensor) + load_state_dict = paddle.load(path_tensor) + ''' + + if os.path.isfile(path): + config = _parse_load_config(configs) + with open(path, 'rb') as f: + # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3' + if sys.platform == 'darwin' and sys.version_info.major == 3: + load_result = _pickle_loads_mac(path, f) + else: + load_result = pickle.load(f) if six.PY2 else pickle.load( + f, encoding='latin1') + + # TODO(weixin):If `obj` is any object, the judgment condition should be more precise. + if isinstance(load_result, dict): + if isinstance(load_result, dict): + load_result = _pack_loaded_dict(load_result) + # paddle2.0: paddle.save/load + if "StructuredToParameterName@@" in load_result: + + for key in load_result["StructuredToParameterName@@"]: + load_result[key] = _ndarray_to_tensor( + load_result[key], config.return_numpy) + + if not config.keep_name_table and "StructuredToParameterName@@" in load_result: + del load_result["StructuredToParameterName@@"] + else: + # paddle2.1 static.save/load + for key in load_result: + load_result[key] = _ndarray_to_tensor( + load_result[key], config.return_numpy) + + else: + # TODO(weixin): support complex objects such as layer. + # If `obj` is any object, the judgment condition should be more precise. + if _transformed_from_lodtensor(load_result): + load_result = _ndarray_to_tensor(load_result, + config.return_numpy) + elif _transformed_from_varbase(load_result): + load_result = _tuple_to_tensor(load_result, + config.return_numpy) + else: + raise NotImplementedError( + 'Only support tensor and state_dict, but received {}.'. 
+ format(type(load_result)))
+
+ else:
+ load_result = _legacy_load(path, **configs)
+
+ return load_result
+
+
+def _legacy_load(path, **configs):
 load_result = None
 config = _parse_load_config(configs)
diff --git a/tools/__pycache__/static_mode_white_list.cpython-37.pyc b/tools/__pycache__/static_mode_white_list.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..937267ff7180abf32341455e5d21fcac417782ae
GIT binary patch
diff --git a/tools/static_mode_white_list.pyc b/tools/static_mode_white_list.pyc
deleted file mode 100644
index e9012c233595b6844f54e625972360f5aeeb0d3b..0000000000000000000000000000000000000000
GIT binary patch
-- 
GitLab
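For orientation, the following is a minimal usage sketch of the save/load behavior this patch introduces, condensed from the docstrings added above. The layer shape, network definition, and file paths are illustrative assumptions only, not part of the patch:

    import paddle

    # --- dynamic graph: save/load a single Tensor (new in this patch) ---
    layer = paddle.nn.Linear(8, 8)
    paddle.save(layer.weight, "temp/linear.weight.pdtensor")
    w_tensor = paddle.load("temp/linear.weight.pdtensor")                      # paddle.Tensor
    w_ndarray = paddle.load("temp/linear.weight.pdtensor", return_numpy=True)  # numpy.ndarray

    # --- static graph: per-variable get_value/set_value and Program.state_dict ---
    paddle.enable_static()
    x = paddle.static.data(name="x", shape=[None, 8], dtype='float32')
    y = paddle.static.nn.fc(x, 8)
    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(paddle.static.default_startup_program())
    prog = paddle.static.default_main_program()

    # save every persistable variable as an individual tensor file
    for var in prog.list_vars():
        if var.persistable:
            paddle.save(var.get_value(), "temp/" + var.name + ".pdtensor")

    # restore them later into the same program
    for var in prog.list_vars():
        if var.persistable:
            var.set_value(paddle.load("temp/" + var.name + ".pdtensor"))

    # whole-program equivalent using the new Program.state_dict/set_state_dict
    paddle.save(prog.state_dict(), "temp/model.pdparams")
    prog.set_state_dict(paddle.load("temp/model.pdparams"))

In dynamic graph mode paddle.load returns a Tensor by default, while in static graph mode it returns a LoDTensor; passing return_numpy=True yields a numpy.ndarray in either mode.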