From 1476e1f99880e3148bdcc3c8c344679d09c6cf50 Mon Sep 17 00:00:00 2001 From: WeiXin Date: Mon, 30 Nov 2020 10:20:30 +0800 Subject: [PATCH] save model after jit.load (#28748) * Changed a variable name error * Add comments * Move member functions of TranslatedLayer out of function * edit code according to review * Edit input argument of '_run_static_graph' * reset due to Segmentation fault * rename variables when stitching graph * modify code according CI * Add comments to '__i_m_p_l__' * remove blanks befor 'Get...' * edit code according to review * Add a comment to '_execution_method_creator' * Edit a comment to '_execution_method_creator' --- .../dygraph_to_static/function_spec.py | 7 + .../fluid/dygraph/dygraph_to_static/utils.py | 19 +- python/paddle/fluid/dygraph/io.py | 469 ++++++++++++++---- .../tests/unittests/test_jit_save_load.py | 89 ++++ 4 files changed, 478 insertions(+), 106 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py index 34fb168495..205766e461 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py @@ -25,8 +25,10 @@ from paddle.fluid.layers.utils import pack_sequence_as from paddle.fluid.dygraph.base import switch_to_static_graph from paddle.fluid.dygraph.dygraph_to_static import logging_utils from paddle.fluid.dygraph.dygraph_to_static.utils import parse_arg_and_kwargs +from paddle.fluid.dygraph.dygraph_to_static.utils import parse_varargs_name from paddle.fluid.dygraph.dygraph_to_static.utils import type_name from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code +from paddle.fluid.dygraph.io import TranslatedLayer class FunctionSpec(object): @@ -45,6 +47,11 @@ class FunctionSpec(object): # parse full argument names list. self._arg_names, self._default_kwargs = parse_arg_and_kwargs(function) + # parse *args + self.varargs_name = parse_varargs_name(function) + if self.varargs_name is not None and isinstance(function.__self__, + TranslatedLayer): + self._arg_names += function.__self__._input_args_names def unified_args_and_kwargs(self, args, kwargs): """ diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index cdb4b8e52d..db3024821f 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -113,6 +113,15 @@ def parse_arg_and_kwargs(function): return arg_names, default_kwargs +def parse_varargs_name(function): + """ + Returns varargs name string of function. e.g: 'input' from `foo(x, *input)` + """ + fullargspec = getfullargspec(function) + varargs = fullargspec.varargs + return varargs + + def type_name(v): return type(v).__name__ @@ -478,11 +487,17 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): else: module = SourceFileLoader(module_name, f.name).load_module() func_name = dyfunc.__name__ - if not hasattr(module, func_name): + # The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained + # through 'func_name'. So set the special function name '__i_m_p_l__'. + if hasattr(module, '__i_m_p_l__'): + callable_func = getattr(module, '__i_m_p_l__') + callable_func.__name__ = func_name + elif hasattr(module, func_name): + callable_func = getattr(module, func_name) + else: raise ValueError( 'Function: %s doesn\'t exist in the Module transformed from AST.' 
% func_name) - callable_func = getattr(module, func_name) # After transform dygraph function into callable_func saved in tmp file, # it lost the global variables from imported statements or defined in source file. # Recovers the necessary variables by `__globals__`. diff --git a/python/paddle/fluid/dygraph/io.py b/python/paddle/fluid/dygraph/io.py index 05d2b0bf1e..ecf560499e 100644 --- a/python/paddle/fluid/dygraph/io.py +++ b/python/paddle/fluid/dygraph/io.py @@ -28,6 +28,7 @@ from paddle.fluid import unique_name from paddle.fluid.dygraph import layers from paddle.fluid.layers import nn from paddle.fluid.dygraph.base import switch_to_static_graph +from paddle.fluid.framework import in_dygraph_mode __all__ = ['TranslatedLayer'] @@ -163,10 +164,17 @@ def _get_loaded_var_new_old(program_desc, all_new_old_dict_all): return new_old_dict -def _rename_var_program_desc(program_desc): +def _rename_var_program_desc(program_desc, include=None, exclude=None): """ Change the name of the loaded variables.Use 'unique_name.generate' to avoid duplication - e.g. x ==> x_0, x_0 ==> x_1 + e.g. linear_0.tmp_3 ==> linear_0.tmp_1, x ==> x_0. + If 'include' is not `None`,variables that are not in include are not renamed. + If 'exclude' is not `None`,variables that are in exclude will are not renamed. + + Args: + program_desc(ProgramDesc):the variables in it will be modified. + include(List):list of names of variables. + exclude(List):list of names of variables. """ dict_rename_var_old_new = dict() dict_rename_var_new_old = dict() @@ -175,25 +183,26 @@ def _rename_var_program_desc(program_desc): cur_block = program_desc.block(b_idx) for var in cur_block.all_vars(): old_names.append(var.name()) - persistable_vars = _get_persistable_vars(program_desc) for b_idx in six.moves.range(program_desc.num_blocks()): cur_block = program_desc.block(b_idx) for var_idx, var in enumerate(cur_block.all_vars()): - if var not in persistable_vars: - continue name_old = var.name() - while True: + should_rename = (include is None or name_old in include) and ( + exclude is None or name_old not in exclude) + if should_rename: temp_name = name_old.split('_') if len(temp_name) > 1 and temp_name[-1].isnumeric(): temp_name = "_".join(temp_name[:-1]) else: - temp_name = "_".join(temp_name) - - name_new = _generate_unique_var_name_sync_with_main_program( - temp_name) - if name_new not in old_names[:var_idx] + old_names[var_idx + - 1:]: - break + temp_name = name_old + while True: + name_new = _generate_unique_var_name_sync_with_main_program( + temp_name) + if name_new not in old_names[:var_idx] + old_names[var_idx + + 1:]: + break + else: + name_new = name_old if name_old != name_new: cur_block._rename_var( cpt.to_bytes(name_old), cpt.to_bytes(name_new)) @@ -300,8 +309,10 @@ class _ProgramHolder(object): return self._inner_scope def _preprocess(self, program_desc): - # rename variables of 'program_desc' - rename_new_old_dict, _ = _rename_var_program_desc(program_desc) + # rename persistable variables of 'program_desc' + list_persistable_var = _get_persistable_var_names(program_desc) + rename_new_old_dict, _ = _rename_var_program_desc(program_desc, + list_persistable_var) # 1. Prune original program # remove feed, fetch and scale-1 op, remove op_callstack attr ops_to_remove = [] @@ -645,6 +656,327 @@ def _construct_params_and_buffers(model_path, return var_dict +def _run_dygraph(instance, input, program_holder): + + # 1. 
prepare inputs, outputs, attrs + input_vars = [] + for i, value in enumerate(input): + if not isinstance(value, (np.ndarray, core.VarBase)): + raise TypeError( + "The type of input in TranslatedLayer must be numpy array or Variable(VarBase), but received %s." + % type(value)) + # NOTE: In order to unify the API, firstly convert the input to VarBase + if isinstance(value, np.ndarray): + var = core.VarBase( + value=value, + name=program_holder.input_descs[i].name(), + persistable=False, + place=framework._current_expected_place(), + zero_copy=True) + else: + var = value + # NOTE: we changed var name here, + # but it may be an important name set by user + var.name = program_holder.input_descs[i].name() + input_vars.append(var) + if instance._input_args_names is None: + instance._input_args_names = [ + ins.name() for ins in program_holder.input_descs + ] + + persistable_vars = [] + for var_name in program_holder.persistable_names: + dy_var_name = instance._persistable_var_name_dict[var_name] + if dy_var_name in instance._parameters: + persistable_vars.append(instance._parameters[dy_var_name]) + elif dy_var_name in instance._buffers: + persistable_vars.append(instance._buffers[dy_var_name]) + else: + raise ValueError( + "The persistable variable %s does not exist in current TranslatedLayer." + % var_name) + + output_vars = [] + for var_desc in program_holder.output_descs: + var = core.VarBase(var_desc.dtype(), + var_desc.shape(), + var_desc.name(), var_desc.type(), False) + output_vars.append(var) + + # hold forward variables + tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [], + "program_out_scope", + core.VarDesc.VarType.STEP_SCOPES, True) + tmp_scope_vec.value().set_scope(program_holder.scope) + + # 2. run program by op + trace_program = program_holder.infer_program if instance._is_test else program_holder.train_program + end_op_index = program_holder.infer_program.block(0).op_size() + framework._dygraph_tracer().trace_op( + type='run_program', + inputs={'X': input_vars, + 'Params': persistable_vars}, + outputs={'Out': output_vars, + 'OutScope': tmp_scope_vec}, + attrs={ + 'global_block': trace_program.block(0), + 'start_op_index': 0, + 'end_op_index': end_op_index, + 'is_test': instance._is_test + }) + # NOTE: [ why need set param's gradient type here ] + # if user set sparse gradient mode, the param's gradient + # will be SelectedRows, not LoDTensor. But tracer will just + # set param grad VarBase by forward VarBase(LoDTensor) + # If we don't change grad_var type here, RunProgramOp need + # transform SelectedRows to LoDTensor forcibly, it may not + # be user wanted result. + for persistable_var in persistable_vars: + grad_var_name = var.name + core.grad_var_suffix() + grad_var = trace_program.block(0).find_var(cpt.to_bytes(grad_var_name)) + # NOTE: cannot find var desc maybe not problem, + # such as in batch_norm + if grad_var is None: + continue + persistable_var._set_grad_type(grad_var.type()) + + # 3. 
prepare output, keep same form with inputs + outs = output_vars + if len(output_vars) == 1: + outs = output_vars[0] + return outs + + +def _run_static_graph(input, program_holder, trace_program): + main_program = framework.default_main_program() + param_var_names = _get_persistable_var_names(trace_program) + _, dict_rename_var_old_new = _rename_var_program_desc( + trace_program, exclude=param_var_names) + trace_program.flush() + output_names = [var.name() for var in program_holder.output_descs] + # append blocks from 'trace_program' + _append_block(main_program, trace_program, program_holder, input, + dict_rename_var_old_new) + main_program._sync_with_cpp() + outs = _get_output_from_program(main_program, program_holder, + dict_rename_var_old_new) + if len(outs) == 1: + outs = outs[0] + return outs + + +def _collect_current_and_parent_var(program, block_idx): + ''' + Get variables in current block and its parent block. + + Args: + program(Program): The program containing the current block. + block_idx(int): index of current block. + + Returns: + List: list of variables. + ''' + vars = [] + if block_idx < 0: + return vars + for var in program.block(block_idx).vars: + vars.append(var) + parent_idx = program.block(block_idx).parent_idx + if parent_idx > -1: + vars += _collect_current_and_parent_var(program, parent_idx) + return vars + + +def _append_block(dest_program, + src_program_desc, + program_holder, + input_variables, + dict_rename_var_old_new=None): + ''' + Append Variables and Operators in 'src_program_desc' to dest_program. + + Args: + dest_program(Program): Variables and Operators are appended to it. + src_program_desc(ProgramDesc): Variables in it will be appended to 'dest_program'. + program_holder(_ProgramHolder): program_holder of TranslatedLayer + input_variables(list): list of input variables + dict_rename_var_old_new(None|dict): When using '_rename_var_program_desc', + use it to map the name of the variable before it was modified and the new name. + ''' + + origin_block_idx = dest_program.current_block_idx + param_var_names = _collect_current_and_parent_var(dest_program, + origin_block_idx) + append_var_from_block_desc_static( + dest_program.block(origin_block_idx), + src_program_desc.block(0), + exclude=param_var_names) + + name_inp_desc = [inp.name() for inp in program_holder.input_descs] + input_names = [inp.name for inp in input_variables] + if len(name_inp_desc) != len(input_names): + raise ValueError( + "The number of input is invalid, expected {}, but received {}.". 
+ format(len(name_inp_desc), len(input_names))) + for i, out_name in enumerate(name_inp_desc): + if dict_rename_var_old_new: + out_name = dict_rename_var_old_new[out_name] + dest_program.block(origin_block_idx).append_op( + type='assign', + inputs={'X': [input_names[i]]}, + outputs={'Out': [out_name]}) + + append_ops = append_op_from_block_desc_static( + dest_program.block(origin_block_idx), src_program_desc.block(0)) + dest_program._sync_with_cpp() + + offset_block_idx = dest_program.num_blocks - 1 + + if src_program_desc.num_blocks() > 1: + for src_block_idx in range(1, src_program_desc.num_blocks()): + src_block = src_program_desc.block(src_block_idx) + src_parent_idx = src_block.parent + if src_parent_idx > 0: + parent_idx = offset_block_idx + parent_idx + else: + parent_idx = origin_block_idx + dest_block = dest_program._create_block(parent_idx=parent_idx) + append_var_from_block_desc_static( + dest_block, src_block, exclude=param_var_names) + append_ops += append_op_from_block_desc_static(dest_block, + src_block) + + dest_program._sync_with_cpp() + for op in append_ops: + if op.has_attr('sub_block'): + sub = op.attr('sub_block') + if isinstance(sub, framework.core.BlockDesc): + origin_id = sub.id + if isinstance(sub, framework.Block): + origin_id = sub.idx + op._set_attr('sub_block', + dest_program.block(offset_block_idx + origin_id)) + dest_program._sync_with_cpp() + dest_program.current_block_idx = origin_block_idx + + +def _get_output_from_program(program, + program_holder, + dict_rename_var_old_new=None): + """ + Get output name of 'program' according to program_holder + """ + outs = list() + for var in program_holder.output_descs: + for idx in range(program.num_blocks): + vars = program.block(idx).vars + var_name = var.name() + if dict_rename_var_old_new: + var_name = dict_rename_var_old_new[var_name] + if var_name in vars: + out = vars[var_name] + if out not in outs: + outs.append(out) + return outs + + +def append_op_from_block_desc_static(block, src_block_desc): + """ + Append Operators of 'src_block_desc' to current block. + + Args: + block(Block): append OP of 'src_block_desc' to it. + src_block_desc(BlockDesc): append var of 'src_block_desc' + + Returns: + List: list of the OP that are append to current block. + """ + ops = [] + for i in range(src_block_desc.op_size()): + ops.append(append_op_from_desc_static(block, src_block_desc.op(i))) + return ops + + +def append_op_from_desc_static(block, op_desc): + """ + Append Operators to 'block' according to 'op_desc'. + + Args: + block(Block): append OP of 'src_block_desc' to it. + op_desc(OpDesc): create OP according to it. + + Returns: + Operator: OP appended to 'block'. + """ + op_type = op_desc.type() + op_append = block.desc.append_op() + op_append.copy_from(op_desc) + op = framework.Operator( + block=block, + desc=op_append, + type=op_type, + inputs=None, + outputs=None, + attrs=None) + block.ops.append(op) + return op + + +def append_var_from_block_desc_static(block, + src_block_desc, + include=None, + exclude=None): + """ + Append Variables of 'src_block_desc' to current block. + If 'include' is not `None`,variables that are not in include are not append. + If 'exclude' is not `None`,variables that are in exclude will are not append. + + Args: + block(Block): append Variables of 'src_block_desc' to it. + src_block_desc(BlockDesc): append var of 'src_block_desc' + include(List):list of names of variables + exclude(List):list of names of variables + + Returns: + List: list of the variables that are append to current block. 
+ """ + vars_append = [] + for var_desc in src_block_desc.all_vars(): + var_desc_name = var_desc.name() + should_append = (include is None or var_desc_name in include) and ( + exclude is None or var_desc_name not in exclude) + if not block.has_var(var_desc_name) and should_append: + var_type = var_desc.type() + if var_type in [ + core.VarDesc.VarType.SELECTED_ROWS, + core.VarDesc.VarType.LOD_TENSOR, + core.VarDesc.VarType.LOD_TENSOR_ARRAY + ]: + data_type = var_desc.dtype() + var_shape = var_desc.shape() + else: + data_type = None + var_shape = None + if var_type in [ + core.VarDesc.VarType.LOD_TENSOR, + core.VarDesc.VarType.LOD_TENSOR_ARRAY + ]: + lod_level = var_desc.lod_level() + else: + lod_level = None + + vars_append.append( + block.create_var( + name=var_desc.name(), + dtype=data_type, + type=var_type, + shape=var_shape, + lod_level=lod_level, + persistable=var_desc.persistable(), + set_need_check_feed=var_desc.need_check_feed())) + return vars_append + + class TranslatedLayer(layers.Layer): """ TranslatedLayer is a ``paddle.nn.Layer`` for holding the model @@ -780,6 +1112,7 @@ class TranslatedLayer(layers.Layer): ) self._is_test = True + self._input_args_names = None @staticmethod @framework.dygraph_only @@ -817,95 +1150,23 @@ class TranslatedLayer(layers.Layer): @staticmethod def _execution_method_creator(method_name, program_holder): - def __impl__(self, *input): - # 1. prepare inputs, outputs, attrs - input_vars = [] - for i, value in enumerate(input): - if not isinstance(value, (np.ndarray, core.VarBase)): - raise TypeError( - "The type of input in TranslatedLayer must be numpy array or Variable(VarBase), but received %s." - % type(value)) - # NOTE: In order to unify the API, firstly convert the input to VarBase - if isinstance(value, np.ndarray): - var = core.VarBase( - value=value, - name=program_holder.input_descs[i].name(), - persistable=False, - place=framework._current_expected_place(), - zero_copy=True) - else: - var = value - # NOTE: we changed var name here, - # but it may be an important name set by user - var.name = program_holder.input_descs[i].name() - input_vars.append(var) - - persistable_vars = [] - for var_name in program_holder.persistable_names: - dy_var_name = self._persistable_var_name_dict[var_name] - if dy_var_name in self._parameters: - persistable_vars.append(self._parameters[dy_var_name]) - elif dy_var_name in self._buffers: - persistable_vars.append(self._buffers[dy_var_name]) - else: - raise ValueError( - "The persistable variable %s is not exists in current TranslatedLayer." - % var_name) - - output_vars = [] - for var_desc in program_holder.output_descs: - var = core.VarBase(var_desc.dtype(), - var_desc.shape(), - var_desc.name(), var_desc.type(), False) - output_vars.append(var) - - # hold forward variables - tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [], - "program_out_scope", - core.VarDesc.VarType.STEP_SCOPES, True) - tmp_scope_vec.value().set_scope(program_holder.scope) - - # 2. 
run program by op - trace_program = program_holder.infer_program if self._is_test else program_holder.train_program - end_op_index = program_holder.infer_program.block(0).op_size() - framework._dygraph_tracer().trace_op( - type='run_program', - inputs={'X': input_vars, - 'Params': persistable_vars}, - outputs={'Out': output_vars, - 'OutScope': tmp_scope_vec}, - attrs={ - 'global_block': trace_program.block(0), - 'start_op_index': 0, - 'end_op_index': end_op_index, - 'is_test': self._is_test - }) - - # NOTE: [ why need set param's gradient type here ] - # if user set sparse gradient mode, the param's gradient - # will be SelectedRows, not LoDTensor. But tracer will just - # set param grad VarBase by forward VarBase(LoDTensor) - # If we don't change grad_var type here, RunProgramOp need - # transform SelectedRows to LoDTensor forcibly, it may not - # be user wanted result. - for persistable_var in persistable_vars: - grad_var_name = var.name + core.grad_var_suffix() - grad_var = trace_program.block(0).find_var( - cpt.to_bytes(grad_var_name)) - # NOTE: cannot find var desc maybe not problem, - # such as in batch_norm - if grad_var is None: - continue - persistable_var._set_grad_type(grad_var.type()) - - # 3. prepare output, keep same form with inputs - outs = output_vars - if len(output_vars) == 1: - outs = output_vars[0] - return outs - - __impl__.__name__ = method_name - return __impl__ + def __i_m_p_l__(self, *input): + program_holder = self._program_holder_dict[__i_m_p_l__.__name__] + # When using jit.save, it runs in static graph mode. + # Run in dynamic graph mode when the model is inferring. + if in_dygraph_mode(): + return _run_dygraph(self, input, program_holder) + else: + # NOTE(weixin): [ why not use 'program_holder.infer_program' directly? ] + # When use '_run_static_graph(input, program_holder, program_holder.infer_program)', + # because '_run_static_graph' modifies 'ProgramDesc', 'OpDesc.op_size()' will return a very large wrong number. + # A Segmentation fault error may occur if used 'p=ProgramDesc(program_holder.infer_program)'. 
+ p = framework.Program._construct_from_desc( + core.ProgramDesc(program_holder.infer_program)) + return _run_static_graph(input, program_holder, p.desc) + + __i_m_p_l__.__name__ = method_name + return __i_m_p_l__ def train(self): self._is_test = False diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py index 258136c3cf..3e0b6a83b4 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py @@ -25,6 +25,7 @@ from paddle.fluid.layers.utils import flatten from paddle.fluid.dygraph import Linear from paddle.fluid.dygraph import declarative, ProgramTranslator from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX +from paddle.fluid import unique_name BATCH_SIZE = 32 BATCH_NUM = 10 @@ -863,6 +864,94 @@ class TestJitSaveLoadMultiMethods(unittest.TestCase): layer, model_path, input_spec=[InputSpec(shape=[None, 784])]) +class LayerSaved(paddle.nn.Layer): + def __init__(self, in_size, out_size): + super(LayerSaved, self).__init__() + self.hidden = 100 + self._linear_0 = Linear(in_size, self.hidden) + self._linear_1_0 = Linear(self.hidden, self.hidden) + self._linear_1_1 = Linear(self.hidden, self.hidden) + self._linear_2 = Linear(self.hidden, out_size) + self._scale = paddle.to_tensor(9.9) + + @paddle.jit.to_static + def forward(self, x): + y = self._linear_0(x) + # Multiple blocks + if x.shape[0] == 1: + y = self._linear_1_0(y) + else: + y += self._linear_1_1(y + self._scale) + return self._linear_2(y) + + +class LayerLoadFinetune(paddle.nn.Layer): + def __init__(self, in_size, out_size, load_path): + super(LayerLoadFinetune, self).__init__() + # Test duplicate name + self._linear_0 = Linear(in_size, in_size) + self._linear_1_0 = Linear(out_size, in_size) + self._linear_1_1 = Linear(out_size, in_size) + self._linear_2 = Linear(out_size, out_size) + self._scale = paddle.to_tensor(9.9) + + # Load multiple times + self._load_l1 = paddle.jit.load(load_path) + self._load_l2 = paddle.jit.load(load_path) + + @paddle.jit.to_static + def forward(self, x): + y = self._linear_0(x) + y = self._load_l1(y) + # Multiple blocks + if x.shape[0] == 1: + y = self._linear_1_0(y) + y = self._load_l1(y) + else: + y += self._linear_1_1(x + self._scale) + y = self._load_l2(y) + y = self._linear_1_0(y) + y = self._load_l1(y) + y = self._linear_1_0(y) + # Use the same layer multiple times. 
+ y = self._load_l1(y) + return y + + +class TestJitSaveLoadFinetuneLoad(unittest.TestCase): + def setUp(self): + # enable dygraph mode + paddle.disable_static() + + def test_save_load_finetune_load(self): + model_path = "test_jit_save_load_finetune_load/model" + IMAGE_SIZE = 224 + inps0 = paddle.randn([1, IMAGE_SIZE]) + inps1 = paddle.randn([2, IMAGE_SIZE]) + # Use new namespace + with unique_name.guard(): + layer_save = LayerSaved(IMAGE_SIZE, IMAGE_SIZE) + layer_save(inps0) + #save + paddle.jit.save(layer_save, model_path) + #load + with unique_name.guard(): + layer_load = LayerLoadFinetune(IMAGE_SIZE, IMAGE_SIZE, model_path) + #train + train(layer_load, input_size=IMAGE_SIZE) + result_00 = layer_load(inps0) + result_01 = layer_load(inps1) + #save + paddle.jit.save(layer_load, model_path) + #load + layer_finetune = paddle.jit.load(model_path) + result_10 = layer_finetune(inps0) + result_11 = layer_finetune(inps1) + + self.assertTrue(float((result_00 - result_10).abs().max()) < 1e-5) + self.assertTrue(float(((result_01 - result_11)).abs().max()) < 1e-5) + + class TestJitSaveLoadDataParallel(unittest.TestCase): def verify_inference_correctness(self, layer, path): layer.eval() -- GitLab
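
The test added above (TestJitSaveLoadFinetuneLoad) exercises the workflow this patch enables: load an inference model with paddle.jit.load, embed the resulting TranslatedLayer in a new network, run or fine-tune it in dygraph mode, and export the composite layer again with paddle.jit.save. The sketch below walks through that flow using the public Paddle 2.0 API; the model path, layer sizes, and class name are illustrative assumptions, and it presumes a model was previously exported to that path with paddle.jit.save.

import paddle

paddle.disable_static()  # dygraph mode (the default in Paddle 2.0+)

# Illustrative path of a model previously exported with paddle.jit.save;
# here it is assumed to map IMAGE_SIZE features to IMAGE_SIZE features.
model_path = "example_model/linear"
IMAGE_SIZE = 224

# paddle.jit.load returns a TranslatedLayer holding the saved program and parameters.
loaded_layer = paddle.jit.load(model_path)

class FinetuneNet(paddle.nn.Layer):
    """Wraps the loaded TranslatedLayer and adds a trainable head."""

    def __init__(self, loaded):
        super(FinetuneNet, self).__init__()
        self._loaded = loaded                       # reuse the loaded sub-layer
        self._head = paddle.nn.Linear(IMAGE_SIZE, IMAGE_SIZE)

    @paddle.jit.to_static
    def forward(self, x):
        return self._head(self._loaded(x))

net = FinetuneNet(loaded_layer)
out = net(paddle.randn([2, IMAGE_SIZE]))            # run once so to_static records the program

# Re-saving a network that contains a TranslatedLayer is the case this patch addresses:
# during paddle.jit.save the loaded program runs through the static-graph path
# (_run_static_graph) rather than the dygraph run_program op.
paddle.jit.save(net, "example_model/finetuned")

# The re-saved model loads and runs like any other exported model.
reloaded = paddle.jit.load("example_model/finetuned")
print(reloaded(paddle.randn([1, IMAGE_SIZE])).shape)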