未验证 提交 2bf57450 编写于 作者: H Hui Zhang 提交者: GitHub

[jit] jit.save support property serialization (#44581)

* jit.save support peropty serilization

* extract set property function

* fix property test file name

* fix typing error

* fix typing error

* fix test coverage
上级 0dae79a9
......@@ -433,15 +433,6 @@ void BindJitProperty(pybind11::module *m) {
"set list of string",
py::arg("name"),
py::arg("val"))
.def("set_tensor",
[](const pd::VarDesc &tensor, const std::string name) {
throw platform::errors::Unimplemented("Not implement set_tensor.");
})
.def(
"set_tensors",
[](const pybind11::list &tensors, const std::string name) {
throw platform::errors::Unimplemented("Not implement set_tensors.");
})
.def("serialize_to_string", SerializeMessage<jit::Property>)
.def("parse_from_string", DeserializeMessage<jit::Property>);
}
......
......@@ -37,6 +37,7 @@ __all__ = ['TranslatedLayer']
INFER_MODEL_SUFFIX = ".pdmodel"
INFER_PARAMS_SUFFIX = ".pdiparams"
INFER_PARAMS_INFO_SUFFIX = ".pdiparams.info"
INFER_PROPERTY_SUFFIX = '.meta'
LOADED_VAR_SUFFIX = "load"
PARAMETER_NAME_PREFIX = "param"
......
......@@ -22,6 +22,7 @@ import functools
from collections import OrderedDict
import inspect
import threading
from typing import Text, Tuple, Any, List
import six
import paddle
......@@ -34,7 +35,7 @@ from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import ConversionOptions, CONVERSION_OPTIONS
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import set_code_level, set_verbosity
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, StaticFunction, unwrap_decorators
from paddle.fluid.dygraph.io import TranslatedLayer, INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX
from paddle.fluid.dygraph.io import TranslatedLayer, INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX, INFER_PROPERTY_SUFFIX
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.executor import Executor, scope_guard
from paddle.fluid.framework import Block, ParamBase, Program, Variable, Parameter, EagerParamBase
......@@ -644,6 +645,40 @@ def _run_save_pre_hooks(func):
return wrapper
def _save_property(filename: Text, property_vals: List[Tuple[Any, Text]]):
"""class property serialization.
Args:
filename (Text): *.meta
property_vals (List[Tuple): class property.
"""
def set_property(meta, key, val):
if isinstance(val, float):
meta.set_float(key, val)
elif isinstance(val, int):
meta.set_int(key, val)
elif isinstance(val, str):
meta.set_string(key, val)
elif isinstance(val, (tuple, list)):
if isinstance(val[0], float):
meta.set_floats(key, val)
elif isinstance(val[0], int):
meta.set_ints(key, val)
elif isinstance(val[0], str):
meta.set_strings(key, val)
else:
raise ValueError(f"Note support val type: {type(val)}")
return
with open(filename, 'wb') as f:
meta = paddle.framework.core.Property()
for item in property_vals:
val, key = item[0], item[1]
set_property(meta, key, val)
f.write(meta.serialize_to_string())
@_run_save_pre_hooks
@switch_to_static_graph
def save(layer, path, input_spec=None, **configs):
......@@ -1043,7 +1078,9 @@ def save(layer, path, input_spec=None, **configs):
filter(paddle.fluid.io.is_persistable,
ordered_vars)),
filename=params_filename)
# TODO: save property
# save property
property_filename = file_prefix + INFER_PROPERTY_SUFFIX
_save_property(property_filename, property_vals)
# NOTE(chenweihang): [ Save extra variable info ]
# save_inference_model will lose some important variable information, including:
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -1156,7 +1157,7 @@ class LayerSaved(paddle.nn.Layer):
class Net(paddle.nn.Layer):
def __init__(self):
super(Net, self).__init__()
super().__init__()
self.fc1 = paddle.nn.Linear(4, 4)
self.fc2 = paddle.nn.Linear(4, 4)
self.bias = 0.4
......@@ -1185,13 +1186,49 @@ class Net(paddle.nn.Layer):
def fbias(self):
return self.bias + 1
# For extra Tensor
@paddle.jit.to_static(property=True)
def down_sampling(self):
return 4
@paddle.jit.to_static(property=True)
def fstr(self):
return "save str property"
@paddle.jit.to_static(property=True)
def ints(self):
return [10, 20]
@paddle.jit.to_static(property=True)
def floats(self):
return [1.1, 2.2]
@paddle.jit.to_static(property=True)
def strs(self):
return ["hello", "world"]
class NetTensor(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.fc1 = paddle.nn.Linear(4, 4)
self.fc2 = paddle.nn.Linear(4, 4)
self.bias = 0.4
self.flag = paddle.ones([2], dtype="int32")
@paddle.jit.to_static(input_spec=[InputSpec([None, 4], dtype='float32')])
def forward(self, x):
out = self.fc1(x)
out = paddle.nn.functional.relu(out)
out = paddle.mean(out)
return out
@paddle.jit.to_static(property=True)
def fflag(self):
return self.flag
return True
class TestJitSaveCombine(unittest.TestCase):
class TestJitSaveCombineProperty(unittest.TestCase):
def setUp(self):
# enable dygraph mode
......@@ -1201,16 +1238,24 @@ class TestJitSaveCombine(unittest.TestCase):
def tearDown(self):
self.temp_dir.cleanup()
def test_save_load_finetune_load(self):
def test_jit_save_combine_property(self):
model_path = os.path.join(self.temp_dir.name,
"test_jit_save_combine/model")
# Use new namespace
with unique_name.guard():
net = Net()
#save
paddle.jit.save(net, model_path, combine_params=True)
def test_jit_save_tensor_property(self):
model_path = os.path.join(self.temp_dir.name,
"test_jit_save_combine/model")
# Use new namespace
with unique_name.guard():
net = NetTensor()
paddle.jit.save(net, model_path, combine_params=True)
class LayerLoadFinetune(paddle.nn.Layer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册