Unverified commit 8a7dee31, authored by jiangcheng, committed by GitHub

graph_to_program save parameter and stop_gradient information (#33771)

This PR adds optional boolean fields is_parameter and stop_gradient to the VarDesc proto, and removes them during save_inference_model.
Parent a59f215d
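In short, the new VarDesc fields let a Program survive the IrGraph round trip without losing which variables are parameters and which have stop_gradient set, and `_remove_training_info()` strips those flags again right before a model is serialized for inference. A rough sketch of that workflow, pieced together from the helpers used in the new unit test below (the variable names are illustrative, not part of this PR):

    import paddle
    from paddle import fluid, static

    paddle.enable_static()

    # Build a small static program with one trainable parameter.
    program = static.Program()
    with static.program_guard(program):
        x = static.data(name='x', shape=[None, 13], dtype='float32')
        hidden = static.nn.fc(x, size=10)
        loss = paddle.mean(hidden)
        paddle.optimizer.SGD(learning_rate=0.01).minimize(loss)

    # Round-trip through the IR graph; with this change the converted
    # program still knows which variables are parameters and which
    # have stop_gradient set.
    graph = fluid.core.Graph(program.desc)
    ir_graph = fluid.framework.IrGraph(graph, for_test=False)
    converted = ir_graph.to_program()
    for var in converted.list_vars():
        print(var.name, var.is_parameter, var.stop_gradient)

    # Before saving for inference, the training-only flags are dropped again.
    model_bytes = converted._remove_training_info().desc.serialize_to_string()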
@@ -173,6 +173,8 @@ message VarDesc {
   // True if the variable is an input data and
   // have to check the feed data shape and dtype
   optional bool need_check_feed = 4 [ default = false ];
+  optional bool is_parameter = 5 [ default = false ];
+  optional bool stop_gradient = 6 [ default = false ];
 }
 
 message BlockDesc {
......
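Both new fields are optional proto2 booleans with a false default, so previously serialized programs still parse, and a program that went through `_remove_training_info()` simply has the fields unset rather than set to false. A hedged sketch of how the flags could be inspected at the protobuf level, reusing the `framework_pb2` module that framework.py already imports (the helper name and the assumption that `ProgramDesc.blocks`/`BlockDesc.vars` are the repeated field names are mine, not part of this diff):

    import six
    from paddle.fluid.proto import framework_pb2

    def dump_var_flags(program):
        # Parse the serialized ProgramDesc back into protobuf messages and
        # print the per-variable flags wherever they were explicitly set.
        desc = framework_pb2.ProgramDesc.FromString(
            six.binary_type(program.desc.serialize_to_string()))
        for block in desc.blocks:
            for var in block.vars:
                if var.HasField('is_parameter') or var.HasField('stop_gradient'):
                    print(var.name, var.is_parameter, var.stop_gradient)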
@@ -111,6 +111,26 @@ class VarDesc {
   void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }
 
+  bool IsParameter() const { return desc_.is_parameter(); }
+
+  void SetIsParameter(bool is_parameter) {
+    desc_.set_is_parameter(is_parameter);
+  }
+
+  void ClearIsParameter() { desc_.clear_is_parameter(); }
+
+  bool HasIsParameter() const { return desc_.has_is_parameter(); }
+
+  bool StopGradient() const { return desc_.stop_gradient(); }
+
+  void SetStopGradient(bool stop_gradient) {
+    desc_.set_stop_gradient(stop_gradient);
+  }
+
+  void ClearStopGradient() { desc_.clear_stop_gradient(); }
+
+  bool HasStopGradient() const { return desc_.has_stop_gradient(); }
+
   bool NeedCheckFeed() const { return desc_.need_check_feed(); }
 
   void SetNeedCheckFeed(bool need_check_feed) {
......
@@ -175,6 +175,14 @@ void BindVarDsec(pybind11::module *m) {
       .def("serialize_to_string", SerializeMessage<pd::VarDesc>)
       .def("persistable", &pd::VarDesc::Persistable)
       .def("set_persistable", &pd::VarDesc::SetPersistable)
+      .def("is_parameter", &pd::VarDesc::IsParameter)
+      .def("set_is_parameter", &pd::VarDesc::SetIsParameter)
+      .def("clear_is_parameter", &pd::VarDesc::ClearIsParameter)
+      .def("has_is_parameter", &pd::VarDesc::HasIsParameter)
+      .def("stop_gradient", &pd::VarDesc::StopGradient)
+      .def("set_stop_gradient", &pd::VarDesc::SetStopGradient)
+      .def("clear_stop_gradient", &pd::VarDesc::ClearStopGradient)
+      .def("has_stop_gradient", &pd::VarDesc::HasStopGradient)
       .def("need_check_feed", &pd::VarDesc::NeedCheckFeed)
       .def("set_need_check_feed", &pd::VarDesc::SetNeedCheckFeed);
......
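These bindings mirror the existing `persistable`/`need_check_feed` accessors, so the flags become reachable from Python through a block's underlying C++ VarDesc objects. A small illustrative sketch (assuming `program` is any `paddle.static.Program`):

    for var_desc in program.global_block().desc.all_vars():
        # has_* reports whether the optional field was ever set;
        # clear_* removes it again, which is what _remove_training_info relies on.
        if var_desc.has_is_parameter():
            print(var_desc.name(), 'is_parameter =', var_desc.is_parameter())
        if var_desc.has_stop_gradient():
            print(var_desc.name(), 'stop_gradient =', var_desc.stop_gradient())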
@@ -1203,7 +1203,7 @@ class Executor(object):
             if vardesc.persistable() == False and \
                 vardesc.type() == core.VarDesc.VarType.LOD_TENSOR and \
                 vardesc.need_check_feed() == True and \
-                varobj._stop_gradient == True and \
+                varobj.stop_gradient == True and \
                 varobj.is_data == True and \
                 varobj.belong_to_optimizer == False and \
                 varname not in feed:
......
@@ -945,7 +945,7 @@ class Variable(object):
                 self.block.vars[name] = self
         self.op = None
-        self._stop_gradient = stop_gradient
+        self.stop_gradient = stop_gradient
         self.is_data = is_data
 
     def detach(self):
@@ -1182,7 +1182,7 @@ class Variable(object):
             var_str = "{name} : {type})".\
                 format(name=self.name, type=type_str)
 
-        if type(self) == Parameter:
+        if self.is_parameter:
             if self.trainable:
                 var_str = "trainable param " + var_str
             else:
@@ -1230,7 +1230,7 @@ class Variable(object):
             proto = framework_pb2.VarDesc.FromString(six.binary_type(protostr))
             res_str = _debug_string_(proto, throw_on_error)
             if with_details:
-                additional_attr = ("error_clip", "stop_gradient")
+                additional_attr = ("error_clip", )
                 for attr_name in additional_attr:
                     res_str += "%s: %s\n" % (attr_name,
                                              cpt.to_text(getattr(self, attr_name)))
@@ -1270,11 +1270,11 @@ class Variable(object):
                 assert linear.weight.gradient() is None
                 assert (out1.gradient() == 0).all()
         """
-        return self._stop_gradient
+        return self.desc.stop_gradient()
 
     @stop_gradient.setter
     def stop_gradient(self, s):
-        self._stop_gradient = s
+        self.desc.set_stop_gradient(s)
 
     @property
     def persistable(self):
@@ -1305,6 +1305,31 @@ class Variable(object):
     def persistable(self, p):
         self.desc.set_persistable(p)
 
+    @property
+    def is_parameter(self):
+        """
+        Indicating if current Variable is a Parameter
+
+        Examples:
+          .. code-block:: python
+
+            import paddle
+            new_parameter = paddle.static.create_parameter(name="X",
+                                                           shape=[10, 23, 48],
+                                                           dtype='float32')
+            if new_parameter.is_parameter:
+                print("Current var is a Parameter")
+            else:
+                print("Current var is not a Parameter")
+
+            # Current var is a Parameter
+        """
+        return self.desc.is_parameter()
+
+    @is_parameter.setter
+    def is_parameter(self, p):
+        self.desc.set_is_parameter(p)
+
     @property
     def name(self):
         """
@@ -2863,12 +2888,7 @@ class Block(object):
             param = ParamBase(*args, **kwargs)
         else:
             param = Parameter(global_block, *args, **kwargs)
-        # NOTE: Why only set stop_gradient=False in static mode
-        # Because in dygraph mode, the `stop_gradient` and `trainable`
-        # are related, and `trainable` default vallue is `True` or
-        # it is specified by users, there is no need to set
-        # `stop_gradient` for ParamBase here.
-        param.stop_gradient = False
         if 'initializer' in kwargs:
 
             def _is_inited_by(block, var):
@@ -3041,7 +3061,23 @@ class Block(object):
         # sync variables from cpp
         for var in self.desc.all_vars():
             if not self.has_var(var.name()):
-                self.create_var(name=var.name(), desc=var, type=var.type())
+                is_stop_gradient = False
+                if var.has_stop_gradient():
+                    is_stop_gradient = var.stop_gradient()
+                if var.has_is_parameter() and var.is_parameter():
+                    self.create_parameter(
+                        name=var.name(),
+                        desc=var,
+                        type=var.type(),
+                        shape=var.shape(),
+                        dtype=var.dtype(),
+                        stop_gradient=is_stop_gradient)
+                else:
+                    self.create_var(
+                        name=var.name(),
+                        desc=var,
+                        type=var.type(),
+                        stop_gradient=is_stop_gradient)
 
         # sync variables removed from c++ end
         for var in list(self.vars.keys()):
@@ -4752,6 +4788,33 @@ class Program(object):
         res._sync_with_cpp()
         return res
 
+    def _remove_training_info(self):
+        """
+        This method will create a new program and do following adjustments on it:
+        1. Remove all variable's `is_parameter` attribute if exist.
+
+        2. Remove all variable's `stop_gradient` attribute if exist.
+
+        Notes: This API is a very low level API.
+
+        Returns:
+            Program: The new program.
+        """
+        res = Program()
+        res.desc = core.ProgramDesc(self.desc)
+        res.blocks = [
+            Block(res, i) for i in six.moves.range(res.desc.num_blocks())
+        ]
+        res._sync_with_cpp()
+
+        for i in six.moves.range(res.desc.num_blocks()):
+            block = res.desc.block(i)
+            for var in block.all_vars():
+                var.clear_is_parameter()
+                var.clear_stop_gradient()
+        return res
+
     @staticmethod
     def parse_from_string(binary_str):
         """
@@ -5399,6 +5462,8 @@ class Parameter(Variable):
         self.is_distributed = False
 
+        self.is_parameter = True
+
     def __str__(self):
         return self._to_readable_code()
......
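Note that `_remove_training_info()` builds its result from a fresh `core.ProgramDesc` copy, so the original program keeps its flags and remains usable for training; only the copy handed to serialization loses them. A minimal sanity check along the lines of the new test_program.py case further below (`program` is illustrative):

    stripped = program._remove_training_info()
    for i in range(stripped.num_blocks):
        for var_desc in stripped.block(i).desc.all_vars():
            assert not var_desc.has_is_parameter()
            assert not var_desc.has_stop_gradient()
    # `program` itself is untouched and still carries the flags.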
@@ -1432,12 +1432,14 @@ def save_inference_model(dirname,
         main_program.desc._set_version()
         paddle.fluid.core.save_op_version_info(main_program.desc)
         with open(model_basename, "wb") as f:
-            f.write(main_program.desc.serialize_to_string())
+            f.write(main_program._remove_training_info()
+                    .desc.serialize_to_string())
     else:
         # TODO(panyx0718): Save more information so that it can also be used
         # for training and more flexible post-processing.
         with open(model_basename + ".main_program", "wb") as f:
-            f.write(main_program.desc.serialize_to_string())
+            f.write(main_program._remove_training_info()
+                    .desc.serialize_to_string())
 
     if program_only:
         warnings.warn(
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import six
import paddle
from paddle import fluid
from paddle import static

paddle.enable_static()


def program_to_IRGraph(program):
    graph = fluid.core.Graph(program.desc)
    ir_graph = fluid.framework.IrGraph(graph, for_test=False)
    return ir_graph


def IRGraph_to_program(ir_graph):
    return ir_graph.to_program()


class GraphToProgramPassTest(unittest.TestCase):
    def check_vars_equal(self, o_block, c_block):
        o_params = sorted(o_block.all_parameters(), key=lambda p: p.name)
        c_params = sorted(c_block.all_parameters(), key=lambda p: p.name)
        self.assertEqual(len(o_params), len(c_params))
        for p_idx in range(len(o_params)):
            self.assertEqual(o_params[p_idx].name, c_params[p_idx].name)

        o_vars = sorted(o_block.vars.values(), key=lambda v: v.name)
        c_vars = sorted(c_block.vars.values(), key=lambda v: v.name)
        self.assertEqual(len(o_vars), len(c_vars))
        for v_idx in range(len(o_vars)):
            self.assertEqual(o_vars[v_idx].name, c_vars[v_idx].name)

    def check_op_output_equal(self, o_op, c_op):
        self.assertEqual(len(o_op.output_names), len(c_op.output_names))
        for out_idx in range(len(o_op.output_names)):
            o_out = o_op.output_names[out_idx]
            c_out = c_op.output_names[out_idx]
            self.assertEqual(o_out, c_out)
            self.assertEqual(o_op.output(o_out), c_op.output(c_out))

    def check_op_input_equal(self, o_op, c_op):
        self.assertEqual(len(o_op.input_names), len(c_op.input_names))
        for in_idx in range(len(o_op.input_names)):
            o_in = o_op.input_names[in_idx]
            c_in = c_op.input_names[in_idx]
            self.assertEqual(o_in, c_in)
            self.assertEqual(o_op.input(o_in), c_op.input(c_in))

    def check_op_attrs_equal(self, o_op, c_op):
        o_attrs = sorted(o_op.attr_names)
        c_attrs = sorted(c_op.attr_names)
        self.assertEqual(len(o_attrs), len(c_attrs))
        for attr_idx in range(len(o_attrs)):
            o_attr = o_attrs[attr_idx]
            c_attr = c_attrs[attr_idx]
            self.assertEqual(o_attr, c_attr)
            self.assertEqual(
                o_op.desc.attr_type(o_attr), c_op.desc.attr_type(c_attr))


class SingleGraphToProgramPass(GraphToProgramPassTest):
    def setUp(self):
        self.origin_program = self.build_program()
        ir_graph = program_to_IRGraph(self.origin_program)
        self.converted_program = IRGraph_to_program(ir_graph)

    @staticmethod
    def build_program():
        program = static.Program()
        with static.program_guard(program):
            data = static.data(name='x', shape=[None, 13], dtype='float32')
            hidden = static.nn.fc(data, size=10)
            loss = paddle.mean(hidden)
            paddle.optimizer.SGD(learning_rate=0.01).minimize(loss)
        return program

    def test_check_parameter(self):
        origin_parameter = sorted(
            self.origin_program.all_parameters(), key=lambda p: p.name)
        converted_parameter = sorted(
            self.converted_program.all_parameters(), key=lambda p: p.name)

        self.assertEqual(len(origin_parameter), len(converted_parameter))
        for i in range(len(origin_parameter)):
            o_para = origin_parameter[i]
            c_para = converted_parameter[i]
            self.assertEqual(o_para.name, c_para.name)
            self.assertEqual(o_para.is_parameter, c_para.is_parameter)

    def test_check_stop_gradient(self):
        origin_vars = list(self.origin_program.list_vars())
        origin_vars = sorted(origin_vars, key=lambda v: v.name)
        converted_vars = list(self.converted_program.list_vars())
        converted_vars = sorted(converted_vars, key=lambda v: v.name)

        self.assertEqual(len(origin_vars), len(converted_vars))
        for i in range(len(origin_vars)):
            o_var = origin_vars[i]
            c_var = converted_vars[i]
            self.assertEqual(o_var.name, c_var.name)
            self.assertEqual(o_var.stop_gradient, c_var.stop_gradient)

    def test_check_ops(self):
        o_block = self.origin_program.global_block()
        c_block = self.converted_program.global_block()
        self.assertEqual(len(o_block.ops), len(c_block.ops))

        # ensure op ordering and content same
        for i in range(len(o_block.ops)):
            o_op = o_block.ops[i]
            c_op = c_block.ops[i]
            self.assertEqual(o_op.type, c_op.type)

            self.check_op_input_equal(o_op, c_op)
            self.check_op_output_equal(o_op, c_op)
            self.check_op_attrs_equal(o_op, c_op)


'''
#TODO(jiangcheng): Open after PR33949 and PR33949 merged
class MultiBlockGraphToProgramPass(GraphToProgramPassTest):
    def setUp(self):
        self.origin_program = self.build_program()
        ir_graph = program_to_IRGraph(self.origin_program)
        self.converted_program = IRGraph_to_program(ir_graph)

    @staticmethod
    def multiblock_model():
        data = static.data(name='t', shape=[None, 10], dtype='float32')
        a = static.data(name='a', shape=[10, 1], dtype='int64')
        b = static.data(name='b', shape=[10, 1], dtype='int64')

        cond = paddle.greater_than(a, b)
        ie = fluid.layers.IfElse(cond)

        with ie.true_block():
            hidden = paddle.nn.functional.relu(data)
            ie.output(hidden)
        with ie.false_block():
            hidden = paddle.nn.functional.softmax(data)
            ie.output(hidden)
        hidden = ie()
        return hidden[0]

    @staticmethod
    def build_program():
        program = static.Program()
        with static.program_guard(program):
            hidden = MultiBlockGraphToProgramPass.multiblock_model()
            loss = paddle.mean(hidden)
            paddle.optimizer.SGD(learning_rate=0.01).minimize(loss)
        return program

    def check_ops_equal(self, o_block, c_block):
        o_ops = o_block.ops
        c_ops = c_block.ops
        self.assertEqual(len(o_ops), len(c_ops))
        for op_idx in range(len(o_ops)):
            o_op = o_ops[op_idx]
            c_op = c_ops[op_idx]
            self.assertEqual(o_op.type, c_op.type)

            self.check_op_input_equal(o_op, c_op)
            self.check_op_output_equal(o_op, c_op)
            self.check_op_attrs_equal(o_op, c_op)

    def check_block_equal(self, o_block, c_block):
        self.check_vars_equal(o_block, c_block)
        self.check_ops_equal(o_block, c_block)

    def test_check_block(self):
        self.assertEqual(self.origin_program.num_blocks,
                         self.converted_program.num_blocks)

        for block_idx in range(self.origin_program.num_blocks):
            o_block = self.origin_program.block(block_idx)
            c_block = self.converted_program.block(block_idx)
            self.assertEqual(o_block.idx, c_block.idx)
            self.check_block_equal(o_block, c_block)
'''

if __name__ == "__main__":
    unittest.main()
@@ -162,6 +162,35 @@ class TestProgram(unittest.TestCase):
         self.assertRaises(TypeError, program._copy_dist_param_info_from,
                           "program")
 
+    def test_remove_training_info(self):
+        def net():
+            reader = fluid.layers.py_reader(
+                capacity=10,
+                shapes=[[-1, 10], [-1, 1]],
+                lod_levels=[0, 0],
+                dtypes=['float32', 'int64'],
+                use_double_buffer=True)
+            in_data, label = fluid.layers.read_file(reader)
+            predict_label = fluid.layers.fc(in_data, size=2, act='softmax')
+            loss = fluid.layers.mean(
+                fluid.layers.cross_entropy(
+                    input=predict_label, label=label))
+
+            optimizer = fluid.optimizer.Adam()
+            optimizer.minimize(loss)
+
+        main_program = fluid.Program()
+        with fluid.program_guard(main_program):
+            net()
+
+        removed_program = main_program._remove_training_info()
+
+        for i in range(removed_program.num_blocks):
+            block = removed_program.block(i)
+            for var in block.desc.all_vars():
+                self.assertFalse(var.has_is_parameter())
+                self.assertFalse(var.has_stop_gradient())
+
 
 if __name__ == '__main__':
     unittest.main()
@@ -510,7 +510,7 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor,
     program = _get_valid_program(kwargs.get('program', None))
     program = normalize_program(program, feed_vars, fetch_vars)
 
     # serialize and save program
-    program_bytes = _serialize_program(program)
+    program_bytes = _serialize_program(program._remove_training_info())
     save_to_file(model_path, program_bytes)
 
     # serialize and save params
     params_bytes = _serialize_persistables(program, executor)
......