Unverified commit 1476e1f9, authored by WeiXin and committed by GitHub

save model after jit.load (#28748)

* Fixed an erroneous variable name

* Add comments

* Move member functions of TranslatedLayer out of function

* edit code according to review

* Edit input argument of '_run_static_graph'

* reset due to Segmentation fault

* rename variables when stitching graph

* modify code according to CI

* Add comments to '__i_m_p_l__'

* remove blanks before 'Get...'

* edit code according to review

* Add a comment to '_execution_method_creator'

* Edit a comment to '_execution_method_creator'
Parent 0239f796
@@ -25,8 +25,10 @@ from paddle.fluid.layers.utils import pack_sequence_as
 from paddle.fluid.dygraph.base import switch_to_static_graph
 from paddle.fluid.dygraph.dygraph_to_static import logging_utils
 from paddle.fluid.dygraph.dygraph_to_static.utils import parse_arg_and_kwargs
+from paddle.fluid.dygraph.dygraph_to_static.utils import parse_varargs_name
 from paddle.fluid.dygraph.dygraph_to_static.utils import type_name
 from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
+from paddle.fluid.dygraph.io import TranslatedLayer


 class FunctionSpec(object):
@@ -45,6 +47,11 @@ class FunctionSpec(object):
         # parse full argument names list.
         self._arg_names, self._default_kwargs = parse_arg_and_kwargs(function)
+        # parse *args
+        self.varargs_name = parse_varargs_name(function)
+        if self.varargs_name is not None and isinstance(function.__self__,
+                                                        TranslatedLayer):
+            self._arg_names += function.__self__._input_args_names

     def unified_args_and_kwargs(self, args, kwargs):
         """
...
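A hedged illustration of what the new FunctionSpec branch does, in plain Python: when the traced method is a bound `forward(self, *input)` of a TranslatedLayer, the positional argument names are extended with the input names recorded by the layer. FakeTranslatedLayer and resolve_arg_names below are hypothetical stand-ins, not Paddle APIs.

import inspect

class FakeTranslatedLayer(object):
    # hypothetical stand-in; the real TranslatedLayer records these at run time
    _input_args_names = ['x', 'y']

    def forward(self, *input):
        return input

def resolve_arg_names(bound_method):
    spec = inspect.getfullargspec(bound_method)
    arg_names = [name for name in spec.args if name != 'self']
    # same idea as the FunctionSpec change: a bare *input on a layer that
    # records its input names is expanded into those concrete names
    if spec.varargs is not None and hasattr(bound_method.__self__,
                                            '_input_args_names'):
        arg_names += bound_method.__self__._input_args_names
    return arg_names

print(resolve_arg_names(FakeTranslatedLayer().forward))  # ['x', 'y']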
@@ -113,6 +113,15 @@ def parse_arg_and_kwargs(function):
     return arg_names, default_kwargs


+def parse_varargs_name(function):
+    """
+    Returns the varargs name of a function, e.g. 'input' from `foo(x, *input)`.
+    """
+    fullargspec = getfullargspec(function)
+    varargs = fullargspec.varargs
+    return varargs
+
+
 def type_name(v):
     return type(v).__name__
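A quick standard-library check of the behaviour this helper relies on:

from inspect import getfullargspec

def foo(x, *input):
    return x

def bar(x):
    return x

print(getfullargspec(foo).varargs)  # 'input'
print(getfullargspec(bar).varargs)  # None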
@@ -478,11 +487,17 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
     else:
         module = SourceFileLoader(module_name, f.name).load_module()
     func_name = dyfunc.__name__
-    if not hasattr(module, func_name):
+    # The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained
+    # through 'func_name'. So set the special function name '__i_m_p_l__'.
+    if hasattr(module, '__i_m_p_l__'):
+        callable_func = getattr(module, '__i_m_p_l__')
+        callable_func.__name__ = func_name
+    elif hasattr(module, func_name):
+        callable_func = getattr(module, func_name)
+    else:
         raise ValueError(
             'Function: %s doesn\'t exist in the Module transformed from AST.' %
             func_name)
-    callable_func = getattr(module, func_name)
     # After transform dygraph function into callable_func saved in tmp file,
     # it lost the global variables from imported statements or defined in source file.
     # Recovers the necessary variables by `__globals__`.
...
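A minimal sketch of the new lookup order, using only the standard library; _resolve_callable is a hypothetical name used purely for illustration:

import types

def _resolve_callable(module, func_name):
    # prefer the special '__i_m_p_l__' name emitted for TranslatedLayer methods
    if hasattr(module, '__i_m_p_l__'):
        fn = getattr(module, '__i_m_p_l__')
        fn.__name__ = func_name
        return fn
    if hasattr(module, func_name):
        return getattr(module, func_name)
    raise ValueError(
        "Function: %s doesn't exist in the Module transformed from AST." %
        func_name)

mod = types.ModuleType('transformed')
exec("def __i_m_p_l__(self, *input):\n    return input", mod.__dict__)
fn = _resolve_callable(mod, 'forward')
print(fn.__name__)  # 'forward'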
@@ -28,6 +28,7 @@ from paddle.fluid import unique_name
 from paddle.fluid.dygraph import layers
 from paddle.fluid.layers import nn
 from paddle.fluid.dygraph.base import switch_to_static_graph
+from paddle.fluid.framework import in_dygraph_mode

 __all__ = ['TranslatedLayer']
@@ -163,10 +164,17 @@ def _get_loaded_var_new_old(program_desc, all_new_old_dict_all):
     return new_old_dict


-def _rename_var_program_desc(program_desc):
+def _rename_var_program_desc(program_desc, include=None, exclude=None):
     """
     Change the name of the loaded variables. Use 'unique_name.generate' to avoid duplication.
-    e.g. x ==> x_0, x_0 ==> x_1
+    e.g. linear_0.tmp_3 ==> linear_0.tmp_1, x ==> x_0.
+    If 'include' is not `None`, variables that are not in 'include' are not renamed.
+    If 'exclude' is not `None`, variables that are in 'exclude' are not renamed.
+
+    Args:
+        program_desc(ProgramDesc): the variables in it will be modified.
+        include(List): list of names of variables.
+        exclude(List): list of names of variables.
     """
     dict_rename_var_old_new = dict()
     dict_rename_var_new_old = dict()
@@ -175,25 +183,26 @@
         cur_block = program_desc.block(b_idx)
         for var in cur_block.all_vars():
             old_names.append(var.name())
-    persistable_vars = _get_persistable_vars(program_desc)
     for b_idx in six.moves.range(program_desc.num_blocks()):
         cur_block = program_desc.block(b_idx)
         for var_idx, var in enumerate(cur_block.all_vars()):
-            if var not in persistable_vars:
-                continue
             name_old = var.name()
-            while True:
+            should_rename = (include is None or name_old in include) and (
+                exclude is None or name_old not in exclude)
+            if should_rename:
                 temp_name = name_old.split('_')
                 if len(temp_name) > 1 and temp_name[-1].isnumeric():
                     temp_name = "_".join(temp_name[:-1])
                 else:
-                    temp_name = "_".join(temp_name)
+                    temp_name = name_old
+                while True:
                     name_new = _generate_unique_var_name_sync_with_main_program(
                         temp_name)
                     if name_new not in old_names[:var_idx] + old_names[var_idx +
                                                                        1:]:
                         break
+            else:
+                name_new = name_old
             if name_old != name_new:
                 cur_block._rename_var(
                     cpt.to_bytes(name_old), cpt.to_bytes(name_new))
@@ -300,8 +309,10 @@ class _ProgramHolder(object):
         return self._inner_scope

     def _preprocess(self, program_desc):
-        # rename variables of 'program_desc'
-        rename_new_old_dict, _ = _rename_var_program_desc(program_desc)
+        # rename persistable variables of 'program_desc'
+        list_persistable_var = _get_persistable_var_names(program_desc)
+        rename_new_old_dict, _ = _rename_var_program_desc(program_desc,
+                                                          list_persistable_var)
         # 1. Prune original program
         # remove feed, fetch and scale-1 op, remove op_callstack attr
         ops_to_remove = []
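To make the include/exclude semantics concrete, here is a hedged, Paddle-free toy of the renaming decision. _toy_unique_name and rename_plan are hypothetical stand-ins for unique_name.generate and _rename_var_program_desc; the real helper additionally avoids clashing with names already present in the program.

import itertools

_counter = itertools.count()

def _toy_unique_name(prefix):
    # stand-in for the unique-name generator used by _rename_var_program_desc
    return "%s_%d" % (prefix, next(_counter))

def rename_plan(names, include=None, exclude=None):
    mapping = {}
    for name in names:
        should_rename = (include is None or name in include) and (
            exclude is None or name not in exclude)
        if should_rename:
            parts = name.split('_')
            if len(parts) > 1 and parts[-1].isnumeric():
                prefix = "_".join(parts[:-1])
            else:
                prefix = name
            mapping[name] = _toy_unique_name(prefix)
        else:
            mapping[name] = name
    return mapping

# only names passed via 'include' are renamed; everything else keeps its name
print(rename_plan(['linear_0.tmp_3', 'x'], include=['linear_0.tmp_3']))
# {'linear_0.tmp_3': 'linear_0.tmp_0', 'x': 'x'}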
@@ -645,6 +656,327 @@ def _construct_params_and_buffers(model_path,
     return var_dict


+def _run_dygraph(instance, input, program_holder):
+
+    # 1. prepare inputs, outputs, attrs
+    input_vars = []
+    for i, value in enumerate(input):
+        if not isinstance(value, (np.ndarray, core.VarBase)):
+            raise TypeError(
+                "The type of input in TranslatedLayer must be numpy array or Variable(VarBase), but received %s."
+                % type(value))
+        # NOTE: In order to unify the API, firstly convert the input to VarBase
+        if isinstance(value, np.ndarray):
+            var = core.VarBase(
+                value=value,
+                name=program_holder.input_descs[i].name(),
+                persistable=False,
+                place=framework._current_expected_place(),
+                zero_copy=True)
+        else:
+            var = value
+            # NOTE: we changed var name here,
+            # but it may be an important name set by user
+            var.name = program_holder.input_descs[i].name()
+        input_vars.append(var)
+    if instance._input_args_names is None:
+        instance._input_args_names = [
+            ins.name() for ins in program_holder.input_descs
+        ]
+
+    persistable_vars = []
+    for var_name in program_holder.persistable_names:
+        dy_var_name = instance._persistable_var_name_dict[var_name]
+        if dy_var_name in instance._parameters:
+            persistable_vars.append(instance._parameters[dy_var_name])
+        elif dy_var_name in instance._buffers:
+            persistable_vars.append(instance._buffers[dy_var_name])
+        else:
+            raise ValueError(
+                "The persistable variable %s does not exist in current TranslatedLayer."
+                % var_name)
+
+    output_vars = []
+    for var_desc in program_holder.output_descs:
+        var = core.VarBase(var_desc.dtype(),
+                           var_desc.shape(),
+                           var_desc.name(), var_desc.type(), False)
+        output_vars.append(var)
+
+    # hold forward variables
+    tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [],
+                                 "program_out_scope",
+                                 core.VarDesc.VarType.STEP_SCOPES, True)
+    tmp_scope_vec.value().set_scope(program_holder.scope)
+
+    # 2. run program by op
+    trace_program = program_holder.infer_program if instance._is_test else program_holder.train_program
+    end_op_index = program_holder.infer_program.block(0).op_size()
+    framework._dygraph_tracer().trace_op(
+        type='run_program',
+        inputs={'X': input_vars,
+                'Params': persistable_vars},
+        outputs={'Out': output_vars,
+                 'OutScope': tmp_scope_vec},
+        attrs={
+            'global_block': trace_program.block(0),
+            'start_op_index': 0,
+            'end_op_index': end_op_index,
+            'is_test': instance._is_test
+        })
+
+    # NOTE: [ why need set param's gradient type here ]
+    # if user set sparse gradient mode, the param's gradient
+    # will be SelectedRows, not LoDTensor. But tracer will just
+    # set param grad VarBase by forward VarBase(LoDTensor)
+    # If we don't change grad_var type here, RunProgramOp need
+    # transform SelectedRows to LoDTensor forcibly, it may not
+    # be user wanted result.
+    for persistable_var in persistable_vars:
+        grad_var_name = persistable_var.name + core.grad_var_suffix()
+        grad_var = trace_program.block(0).find_var(cpt.to_bytes(grad_var_name))
+        # NOTE: cannot find var desc maybe not problem,
+        # such as in batch_norm
+        if grad_var is None:
+            continue
+        persistable_var._set_grad_type(grad_var.type())
+
+    # 3. prepare output, keep same form with inputs
+    outs = output_vars
+    if len(output_vars) == 1:
+        outs = output_vars[0]
+
+    return outs
+
+
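A usage-level sketch of the dygraph path, using only public APIs (the save directory is a placeholder): calling a loaded TranslatedLayer in dygraph mode goes through _run_dygraph.

import paddle

class SmallNet(paddle.nn.Layer):
    def __init__(self):
        super(SmallNet, self).__init__()
        self._fc = paddle.nn.Linear(4, 2)

    @paddle.jit.to_static
    def forward(self, x):
        return self._fc(x)

net = SmallNet()
x = paddle.randn([3, 4])
net(x)  # run once so a concrete program is traced
paddle.jit.save(net, "tmp_small_net/model")  # placeholder path

loaded = paddle.jit.load("tmp_small_net/model")  # a TranslatedLayer
loaded.eval()
out = loaded(x)  # executed in dygraph mode, i.e. the _run_dygraph branch
print(out.shape)  # [3, 2]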
+def _run_static_graph(input, program_holder, trace_program):
+    main_program = framework.default_main_program()
+    param_var_names = _get_persistable_var_names(trace_program)
+    _, dict_rename_var_old_new = _rename_var_program_desc(
+        trace_program, exclude=param_var_names)
+    trace_program.flush()
+    output_names = [var.name() for var in program_holder.output_descs]
+    # append blocks from 'trace_program'
+    _append_block(main_program, trace_program, program_holder, input,
+                  dict_rename_var_old_new)
+    main_program._sync_with_cpp()
+    outs = _get_output_from_program(main_program, program_holder,
+                                    dict_rename_var_old_new)
+    if len(outs) == 1:
+        outs = outs[0]
+    return outs
+
+
+def _collect_current_and_parent_var(program, block_idx):
+    '''
+    Get variables in current block and its parent block.
+
+    Args:
+        program(Program): The program containing the current block.
+        block_idx(int): index of current block.
+
+    Returns:
+        List: list of variables.
+    '''
+    vars = []
+    if block_idx < 0:
+        return vars
+    for var in program.block(block_idx).vars:
+        vars.append(var)
+    parent_idx = program.block(block_idx).parent_idx
+    if parent_idx > -1:
+        vars += _collect_current_and_parent_var(program, parent_idx)
+    return vars
+
+
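A hedged toy model of the parent-block walk, with plain dicts standing in for Program blocks:

# each toy block lists its variable names and the index of its parent (-1 = none)
blocks = [
    {'vars': ['x', 'linear_0.w_0'], 'parent_idx': -1},
    {'vars': ['tmp_0'], 'parent_idx': 0},
]

def collect_current_and_parent_var(blocks, block_idx):
    if block_idx < 0:
        return []
    collected = list(blocks[block_idx]['vars'])
    parent_idx = blocks[block_idx]['parent_idx']
    if parent_idx > -1:
        collected += collect_current_and_parent_var(blocks, parent_idx)
    return collected

print(collect_current_and_parent_var(blocks, 1))
# ['tmp_0', 'x', 'linear_0.w_0']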
+def _append_block(dest_program,
+                  src_program_desc,
+                  program_holder,
+                  input_variables,
+                  dict_rename_var_old_new=None):
+    '''
+    Append Variables and Operators in 'src_program_desc' to dest_program.
+
+    Args:
+        dest_program(Program): Variables and Operators are appended to it.
+        src_program_desc(ProgramDesc): Variables in it will be appended to 'dest_program'.
+        program_holder(_ProgramHolder): program_holder of TranslatedLayer
+        input_variables(list): list of input variables
+        dict_rename_var_old_new(None|dict): When using '_rename_var_program_desc',
+            use it to map the name of the variable before it was modified and the new name.
+    '''
+    origin_block_idx = dest_program.current_block_idx
+    param_var_names = _collect_current_and_parent_var(dest_program,
+                                                      origin_block_idx)
+    append_var_from_block_desc_static(
+        dest_program.block(origin_block_idx),
+        src_program_desc.block(0),
+        exclude=param_var_names)
+
+    name_inp_desc = [inp.name() for inp in program_holder.input_descs]
+    input_names = [inp.name for inp in input_variables]
+    if len(name_inp_desc) != len(input_names):
+        raise ValueError(
+            "The number of input is invalid, expected {}, but received {}.".
+            format(len(name_inp_desc), len(input_names)))
+    for i, out_name in enumerate(name_inp_desc):
+        if dict_rename_var_old_new:
+            out_name = dict_rename_var_old_new[out_name]
+        dest_program.block(origin_block_idx).append_op(
+            type='assign',
+            inputs={'X': [input_names[i]]},
+            outputs={'Out': [out_name]})
+
+    append_ops = append_op_from_block_desc_static(
+        dest_program.block(origin_block_idx), src_program_desc.block(0))
+    dest_program._sync_with_cpp()
+
+    offset_block_idx = dest_program.num_blocks - 1
+
+    if src_program_desc.num_blocks() > 1:
+        for src_block_idx in range(1, src_program_desc.num_blocks()):
+            src_block = src_program_desc.block(src_block_idx)
+            src_parent_idx = src_block.parent
+            if src_parent_idx > 0:
+                parent_idx = offset_block_idx + src_parent_idx
+            else:
+                parent_idx = origin_block_idx
+            dest_block = dest_program._create_block(parent_idx=parent_idx)
+            append_var_from_block_desc_static(
+                dest_block, src_block, exclude=param_var_names)
+            append_ops += append_op_from_block_desc_static(dest_block,
+                                                           src_block)
+
+    dest_program._sync_with_cpp()
+    for op in append_ops:
+        if op.has_attr('sub_block'):
+            sub = op.attr('sub_block')
+            if isinstance(sub, framework.core.BlockDesc):
+                origin_id = sub.id
+            if isinstance(sub, framework.Block):
+                origin_id = sub.idx
+            op._set_attr('sub_block',
+                         dest_program.block(offset_block_idx + origin_id))
+    dest_program._sync_with_cpp()
+    dest_program.current_block_idx = origin_block_idx
+
+
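A brief toy of the block-index bookkeeping, assuming non-root parent indices are shifted by the same offset used when retargeting the 'sub_block' attribute of control-flow ops (the numbers are placeholders):

offset_block_idx = 5   # destination program already ends at block 5
src_parent_idx = 2     # parent index recorded inside the source ProgramDesc

# non-root parents are shifted; a root parent maps to the current dest block
parent_idx = offset_block_idx + src_parent_idx if src_parent_idx > 0 else 0
print(parent_idx)  # 7

origin_sub_block_id = 3  # 'sub_block' attr of a copied while/cond op
print(offset_block_idx + origin_sub_block_id)  # 8 -> new sub_block index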
+def _get_output_from_program(program,
+                             program_holder,
+                             dict_rename_var_old_new=None):
+    """
+    Get output variables of 'program' according to 'program_holder'.
+    """
+    outs = list()
+    for var in program_holder.output_descs:
+        for idx in range(program.num_blocks):
+            vars = program.block(idx).vars
+            var_name = var.name()
+            if dict_rename_var_old_new:
+                var_name = dict_rename_var_old_new[var_name]
+            if var_name in vars:
+                out = vars[var_name]
+                if out not in outs:
+                    outs.append(out)
+    return outs
+
+
+def append_op_from_block_desc_static(block, src_block_desc):
+    """
+    Append Operators of 'src_block_desc' to current block.
+
+    Args:
+        block(Block): append OPs of 'src_block_desc' to it.
+        src_block_desc(BlockDesc): source of the OPs to append.
+
+    Returns:
+        List: list of the OPs appended to the current block.
+    """
+    ops = []
+    for i in range(src_block_desc.op_size()):
+        ops.append(append_op_from_desc_static(block, src_block_desc.op(i)))
+    return ops
+
+
+def append_op_from_desc_static(block, op_desc):
+    """
+    Append an Operator to 'block' according to 'op_desc'.
+
+    Args:
+        block(Block): append the OP to it.
+        op_desc(OpDesc): create the OP according to it.
+
+    Returns:
+        Operator: OP appended to 'block'.
+    """
+    op_type = op_desc.type()
+    op_append = block.desc.append_op()
+    op_append.copy_from(op_desc)
+    op = framework.Operator(
+        block=block,
+        desc=op_append,
+        type=op_type,
+        inputs=None,
+        outputs=None,
+        attrs=None)
+    block.ops.append(op)
+    return op
+
+
+def append_var_from_block_desc_static(block,
+                                      src_block_desc,
+                                      include=None,
+                                      exclude=None):
+    """
+    Append Variables of 'src_block_desc' to current block.
+    If 'include' is not `None`, variables that are not in 'include' are not appended.
+    If 'exclude' is not `None`, variables that are in 'exclude' are not appended.
+
+    Args:
+        block(Block): append Variables of 'src_block_desc' to it.
+        src_block_desc(BlockDesc): source of the Variables to append.
+        include(List): list of names of variables.
+        exclude(List): list of names of variables.
+
+    Returns:
+        List: list of the variables appended to the current block.
+    """
+    vars_append = []
+    for var_desc in src_block_desc.all_vars():
+        var_desc_name = var_desc.name()
+        should_append = (include is None or var_desc_name in include) and (
+            exclude is None or var_desc_name not in exclude)
+        if not block.has_var(var_desc_name) and should_append:
+            var_type = var_desc.type()
+            if var_type in [
+                    core.VarDesc.VarType.SELECTED_ROWS,
+                    core.VarDesc.VarType.LOD_TENSOR,
+                    core.VarDesc.VarType.LOD_TENSOR_ARRAY
+            ]:
+                data_type = var_desc.dtype()
+                var_shape = var_desc.shape()
+            else:
+                data_type = None
+                var_shape = None
+            if var_type in [
+                    core.VarDesc.VarType.LOD_TENSOR,
+                    core.VarDesc.VarType.LOD_TENSOR_ARRAY
+            ]:
+                lod_level = var_desc.lod_level()
+            else:
+                lod_level = None
+            vars_append.append(
+                block.create_var(
+                    name=var_desc.name(),
+                    dtype=data_type,
+                    type=var_type,
+                    shape=var_shape,
+                    lod_level=lod_level,
+                    persistable=var_desc.persistable(),
+                    set_need_check_feed=var_desc.need_check_feed()))
+    return vars_append
+
+
 class TranslatedLayer(layers.Layer):
     """
     TranslatedLayer is a ``paddle.nn.Layer`` for holding the model
@@ -780,6 +1112,7 @@ class TranslatedLayer(layers.Layer):
             )

         self._is_test = True
+        self._input_args_names = None

     @staticmethod
     @framework.dygraph_only
@@ -817,95 +1150,23 @@ class TranslatedLayer(layers.Layer):

     @staticmethod
     def _execution_method_creator(method_name, program_holder):
-        def __impl__(self, *input):
-            # 1. prepare inputs, outputs, attrs
-            input_vars = []
-            for i, value in enumerate(input):
-                if not isinstance(value, (np.ndarray, core.VarBase)):
-                    raise TypeError(
-                        "The type of input in TranslatedLayer must be numpy array or Variable(VarBase), but received %s."
-                        % type(value))
-                # NOTE: In order to unify the API, firstly convert the input to VarBase
-                if isinstance(value, np.ndarray):
-                    var = core.VarBase(
-                        value=value,
-                        name=program_holder.input_descs[i].name(),
-                        persistable=False,
-                        place=framework._current_expected_place(),
-                        zero_copy=True)
-                else:
-                    var = value
-                    # NOTE: we changed var name here,
-                    # but it may be an important name set by user
-                    var.name = program_holder.input_descs[i].name()
-                input_vars.append(var)
-
-            persistable_vars = []
-            for var_name in program_holder.persistable_names:
-                dy_var_name = self._persistable_var_name_dict[var_name]
-                if dy_var_name in self._parameters:
-                    persistable_vars.append(self._parameters[dy_var_name])
-                elif dy_var_name in self._buffers:
-                    persistable_vars.append(self._buffers[dy_var_name])
-                else:
-                    raise ValueError(
-                        "The persistable variable %s is not exists in current TranslatedLayer."
-                        % var_name)
-
-            output_vars = []
-            for var_desc in program_holder.output_descs:
-                var = core.VarBase(var_desc.dtype(),
-                                   var_desc.shape(),
-                                   var_desc.name(), var_desc.type(), False)
-                output_vars.append(var)
-
-            # hold forward variables
-            tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [],
-                                         "program_out_scope",
-                                         core.VarDesc.VarType.STEP_SCOPES, True)
-            tmp_scope_vec.value().set_scope(program_holder.scope)
-
-            # 2. run program by op
-            trace_program = program_holder.infer_program if self._is_test else program_holder.train_program
-            end_op_index = program_holder.infer_program.block(0).op_size()
-            framework._dygraph_tracer().trace_op(
-                type='run_program',
-                inputs={'X': input_vars,
-                        'Params': persistable_vars},
-                outputs={'Out': output_vars,
-                         'OutScope': tmp_scope_vec},
-                attrs={
-                    'global_block': trace_program.block(0),
-                    'start_op_index': 0,
-                    'end_op_index': end_op_index,
-                    'is_test': self._is_test
-                })
-
-            # NOTE: [ why need set param's gradient type here ]
-            # if user set sparse gradient mode, the param's gradient
-            # will be SelectedRows, not LoDTensor. But tracer will just
-            # set param grad VarBase by forward VarBase(LoDTensor)
-            # If we don't change grad_var type here, RunProgramOp need
-            # transform SelectedRows to LoDTensor forcibly, it may not
-            # be user wanted result.
-            for persistable_var in persistable_vars:
-                grad_var_name = var.name + core.grad_var_suffix()
-                grad_var = trace_program.block(0).find_var(
-                    cpt.to_bytes(grad_var_name))
-                # NOTE: cannot find var desc maybe not problem,
-                # such as in batch_norm
-                if grad_var is None:
-                    continue
-                persistable_var._set_grad_type(grad_var.type())
-
-            # 3. prepare output, keep same form with inputs
-            outs = output_vars
-            if len(output_vars) == 1:
-                outs = output_vars[0]
-            return outs
-
-        __impl__.__name__ = method_name
-        return __impl__
+        def __i_m_p_l__(self, *input):
+            program_holder = self._program_holder_dict[__i_m_p_l__.__name__]
+            # When using jit.save, it runs in static graph mode.
+            # Run in dynamic graph mode when the model is inferring.
+            if in_dygraph_mode():
+                return _run_dygraph(self, input, program_holder)
+            else:
+                # NOTE(weixin): [ why not use 'program_holder.infer_program' directly? ]
+                # If '_run_static_graph(input, program_holder, program_holder.infer_program)'
+                # is used directly, '_run_static_graph' modifies the 'ProgramDesc' in place
+                # and 'OpDesc.op_size()' will return a very large wrong number.
+                # A segmentation fault may occur if 'p = ProgramDesc(program_holder.infer_program)' is used.
+                p = framework.Program._construct_from_desc(
+                    core.ProgramDesc(program_holder.infer_program))
+                return _run_static_graph(input, program_holder, p.desc)

+        __i_m_p_l__.__name__ = method_name
+        return __i_m_p_l__

     def train(self):
         self._is_test = False
...
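A usage-level sketch of the static-graph branch, using only public APIs (the paths and the Wrapper class are placeholders): saving a layer that embeds a loaded TranslatedLayer is the case that dispatches to _run_static_graph during jit.save.

import paddle

loaded = paddle.jit.load("tmp_small_net/model")  # placeholder path, see the earlier sketch

class Wrapper(paddle.nn.Layer):
    def __init__(self, inner):
        super(Wrapper, self).__init__()
        self._inner = inner                 # a TranslatedLayer sublayer
        self._fc = paddle.nn.Linear(2, 2)

    @paddle.jit.to_static
    def forward(self, x):
        return self._fc(self._inner(x))

wrapper = Wrapper(loaded)
wrapper(paddle.randn([3, 4]))               # dygraph run -> _run_dygraph
paddle.jit.save(wrapper, "tmp_wrapper/model")  # static-graph run -> _run_static_graph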
@@ -25,6 +25,7 @@ from paddle.fluid.layers.utils import flatten
 from paddle.fluid.dygraph import Linear
 from paddle.fluid.dygraph import declarative, ProgramTranslator
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX
+from paddle.fluid import unique_name

 BATCH_SIZE = 32
 BATCH_NUM = 10
@@ -863,6 +864,94 @@ class TestJitSaveLoadMultiMethods(unittest.TestCase):
             layer, model_path, input_spec=[InputSpec(shape=[None, 784])])
+class LayerSaved(paddle.nn.Layer):
+    def __init__(self, in_size, out_size):
+        super(LayerSaved, self).__init__()
+        self.hidden = 100
+        self._linear_0 = Linear(in_size, self.hidden)
+        self._linear_1_0 = Linear(self.hidden, self.hidden)
+        self._linear_1_1 = Linear(self.hidden, self.hidden)
+        self._linear_2 = Linear(self.hidden, out_size)
+        self._scale = paddle.to_tensor(9.9)
+
+    @paddle.jit.to_static
+    def forward(self, x):
+        y = self._linear_0(x)
+        # Multiple blocks
+        if x.shape[0] == 1:
+            y = self._linear_1_0(y)
+        else:
+            y += self._linear_1_1(y + self._scale)
+        return self._linear_2(y)
+
+
+class LayerLoadFinetune(paddle.nn.Layer):
+    def __init__(self, in_size, out_size, load_path):
+        super(LayerLoadFinetune, self).__init__()
+        # Test duplicate name
+        self._linear_0 = Linear(in_size, in_size)
+        self._linear_1_0 = Linear(out_size, in_size)
+        self._linear_1_1 = Linear(out_size, in_size)
+        self._linear_2 = Linear(out_size, out_size)
+        self._scale = paddle.to_tensor(9.9)
+        # Load multiple times
+        self._load_l1 = paddle.jit.load(load_path)
+        self._load_l2 = paddle.jit.load(load_path)
+
+    @paddle.jit.to_static
+    def forward(self, x):
+        y = self._linear_0(x)
+        y = self._load_l1(y)
+        # Multiple blocks
+        if x.shape[0] == 1:
+            y = self._linear_1_0(y)
+            y = self._load_l1(y)
+        else:
+            y += self._linear_1_1(x + self._scale)
+            y = self._load_l2(y)
+            y = self._linear_1_0(y)
+            y = self._load_l1(y)
+            y = self._linear_1_0(y)
+        # Use the same layer multiple times.
+        y = self._load_l1(y)
+        return y
+
+
+class TestJitSaveLoadFinetuneLoad(unittest.TestCase):
+    def setUp(self):
+        # enable dygraph mode
+        paddle.disable_static()
+
+    def test_save_load_finetune_load(self):
+        model_path = "test_jit_save_load_finetune_load/model"
+        IMAGE_SIZE = 224
+        inps0 = paddle.randn([1, IMAGE_SIZE])
+        inps1 = paddle.randn([2, IMAGE_SIZE])
+        # Use new namespace
+        with unique_name.guard():
+            layer_save = LayerSaved(IMAGE_SIZE, IMAGE_SIZE)
+        layer_save(inps0)
+        # save
+        paddle.jit.save(layer_save, model_path)
+        # load
+        with unique_name.guard():
+            layer_load = LayerLoadFinetune(IMAGE_SIZE, IMAGE_SIZE, model_path)
+        # train
+        train(layer_load, input_size=IMAGE_SIZE)
+        result_00 = layer_load(inps0)
+        result_01 = layer_load(inps1)
+        # save
+        paddle.jit.save(layer_load, model_path)
+        # load
+        layer_finetune = paddle.jit.load(model_path)
+        result_10 = layer_finetune(inps0)
+        result_11 = layer_finetune(inps1)
+
+        self.assertTrue(float((result_00 - result_10).abs().max()) < 1e-5)
+        self.assertTrue(float((result_01 - result_11).abs().max()) < 1e-5)
+
+
 class TestJitSaveLoadDataParallel(unittest.TestCase):
     def verify_inference_correctness(self, layer, path):
         layer.eval()
...