未验证 提交 c09de13e 编写于 作者: C Chen Weihang 提交者: GitHub

Refine jit load model by extra_var_info (#26461)

* refine load model by extra_var_info

* polish unittest for coverage
上级 83cd1859
......@@ -437,8 +437,16 @@ def _load_persistable_vars(model_path,
value: key
for key, value in program_holder._suffix_varname_dict.items()
}
# NOTE: some var may not be Parameter
for name in sorted(extra_var_info):
# NOTE(chenweihang): we need load persistable vars based the program,
# because the program may be pruned when `save_inference_model`, some
# var in `extra_var_info` may have been pruned
for name in sorted(inv_suffix_varname_dict):
if name not in extra_var_info:
raise RuntimeError(
"The model to be loaded is not complete."
"The variable `%s` of program cannot be found in loaded model.",
name)
# get suffix var name, see [why need to append suffix to persistable vars]
new_name = inv_suffix_varname_dict[name]
# create output varbase
......@@ -641,19 +649,21 @@ class TranslatedLayer(layers.Layer):
# name contains `.` originally, such as `linear_0.w_0`, so here
# need to generate new var name for each var
self._persistable_var_name_dict = dict()
for name, var in persistable_vars.items():
if isinstance(var, framework.ParamBase):
dy_name = _generate_unique_var_name(PARAMETER_NAME_PREFIX)
self._persistable_var_name_dict[name] = dy_name
self.add_parameter(dy_name, var)
elif isinstance(var, core.VarBase):
dy_name = _generate_unique_var_name(BUFFER_NAME_PREFIX)
self._persistable_var_name_dict[name] = dy_name
self.register_buffer(dy_name, var)
else:
raise TypeError(
"Adding persistent variable which to layer is not supported now"
)
# the TranslatedLayer object holded var names count started from 0
with unique_name.guard():
for name, var in persistable_vars.items():
if isinstance(var, framework.ParamBase):
dy_name = _generate_unique_var_name(PARAMETER_NAME_PREFIX)
self._persistable_var_name_dict[name] = dy_name
self.add_parameter(dy_name, var)
elif isinstance(var, core.VarBase):
dy_name = _generate_unique_var_name(BUFFER_NAME_PREFIX)
self._persistable_var_name_dict[name] = dy_name
self.register_buffer(dy_name, var)
else:
raise TypeError(
"Adding persistent variable which to layer is not supported now"
)
self._is_test = True
......
......@@ -15,6 +15,7 @@
from __future__ import print_function
import os
import pickle
import unittest
import numpy as np
......@@ -25,7 +26,7 @@ from paddle.fluid.dygraph import declarative, ProgramTranslator
from paddle.fluid.dygraph.io import VARIABLE_FILENAME, EXTRA_VAR_INFO_FILENAME
BATCH_SIZE = 32
BATCH_NUM = 20
BATCH_NUM = 10
SEED = 10
......@@ -318,5 +319,76 @@ class TestJitMultipleLoading(unittest.TestCase):
name_set.add(var.name)
class LinearNetReturnHidden(fluid.dygraph.Layer):
def __init__(self, in_size, out_size):
super(LinearNetReturnHidden, self).__init__()
self._linear_1 = Linear(in_size, out_size)
self._linear_2 = Linear(in_size, out_size)
@declarative
def forward(self, x):
y = self._linear_1(x)
z = self._linear_2(y)
loss = fluid.layers.mean(z)
return y, loss
class TestJitPruneModelAndLoad(unittest.TestCase):
def setUp(self):
self.linear_size = 4
self.model_path = "model.jit_prune_model_and_load"
# enable dygraph mode
fluid.enable_dygraph()
# config seed
fluid.default_main_program().random_seed = SEED
def train_and_save(self):
train_layer = LinearNetReturnHidden(8, 8)
adam = fluid.optimizer.AdamOptimizer(
learning_rate=0.1, parameter_list=train_layer.parameters())
x = fluid.dygraph.to_variable(
np.random.random((4, 8)).astype('float32'))
for i in range(10):
hidden, loss = train_layer(x)
loss.backward()
adam.minimize(loss)
train_layer.clear_gradients()
configs = fluid.dygraph.jit.SaveLoadConfig()
configs.output_spec = [hidden]
fluid.dygraph.jit.save(
layer=train_layer,
model_path=self.model_path,
input_spec=[x],
configs=configs)
return train_layer
def test_load_pruned_model(self):
train_layer = self.train_and_save()
train_layer.eval()
infer_layer = fluid.dygraph.jit.load(self.model_path)
x = fluid.dygraph.to_variable(
np.random.random((4, 8)).astype('float32'))
self.assertTrue(
np.array_equal(train_layer(x)[0].numpy(), infer_layer(x).numpy()))
def test_load_var_not_in_extra_var_info(self):
self.train_and_save()
# chage extra var info
var_info_path = os.path.join(self.model_path, EXTRA_VAR_INFO_FILENAME)
with open(var_info_path, 'rb') as f:
extra_var_info = pickle.load(f)
extra_var_info.clear()
with open(var_info_path, 'wb') as f:
pickle.dump(extra_var_info, f, protocol=2)
with self.assertRaises(RuntimeError):
fluid.dygraph.jit.load(self.model_path)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册