未验证 提交 992d0d93 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat & Quantization]Support append customize attributes into op_desc in nn.Layer (#33359)

* Support append customize attributes into op_desc in nn.Layer

* fix code style

* support override

* add unittest
上级 1382cd22
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
import warnings
from paddle.fluid.framework import default_main_program, in_dygraph_mode
class LayerOpsRecoder:
"""
Record generated operators information in nn.Layer.
"""
def __init__(self, start=-1, end=-1, ops=None, is_valid=False, hooks=None):
self.start = start
self.end = end
self.ops = ops
self.is_valid = is_valid
self.hooks = hooks
def record_program_ops_pre_hook(layer, inputs):
"""
A pre-hook to mark op numbers before enter layer.forward.
"""
if not in_dygraph_mode():
if layer._op_recorder.start < 0:
layer._op_recorder.start = len(default_main_program().current_block(
).ops)
layer._op_recorder.is_valid = True
else:
layer._op_recorder.is_valid = False
warnings.warn(
"{} has recorded the op information before. Please check whether you call this layer twice.".
format(layer._full_name))
return None
def set_op_customized_attrs_post_hook(layer, inputs, outputs):
"""
A post-hook to append customized attributes into all operators generated in current layer.
"""
if not in_dygraph_mode() and layer._op_recorder.is_valid:
start = layer._op_recorder.start
end = len(default_main_program().current_block().ops)
assert (start >= 0 and end >= start)
ops = default_main_program().current_block().ops[start:end]
layer._op_recorder.end = end
layer._op_recorder.ops = ops
for op in ops:
for attr_name, val in six.iteritems(layer._customized_attrs):
op._set_attr(attr_name, val)
# remove pre-hook and post-hook
for hook_helper in layer._op_recorder.hooks:
hook_helper.remove()
return None
...@@ -30,6 +30,7 @@ from . import parallel_helper ...@@ -30,6 +30,7 @@ from . import parallel_helper
from .. import unique_name from .. import unique_name
from paddle.fluid import core from paddle.fluid import core
from .layer_object_helper import LayerObjectHelper from .layer_object_helper import LayerObjectHelper
from .layer_hooks import record_program_ops_pre_hook, set_op_customized_attrs_post_hook, LayerOpsRecoder
from .base import program_desc_tracing_guard, param_guard from .base import program_desc_tracing_guard, param_guard
from paddle.fluid import framework from paddle.fluid import framework
from ..param_attr import ParamAttr from ..param_attr import ParamAttr
...@@ -113,6 +114,10 @@ class Layer(core.Layer): ...@@ -113,6 +114,10 @@ class Layer(core.Layer):
self._sub_layers = collections.OrderedDict() self._sub_layers = collections.OrderedDict()
self._loaddict_holder = collections.OrderedDict() self._loaddict_holder = collections.OrderedDict()
# Record generated op_descs in this layer
self._op_recorder = LayerOpsRecoder(ops=[], hooks=[])
self._customized_attrs = {}
self._forward_pre_hooks = collections.OrderedDict() self._forward_pre_hooks = collections.OrderedDict()
self._forward_post_hooks = collections.OrderedDict() self._forward_post_hooks = collections.OrderedDict()
...@@ -665,7 +670,7 @@ class Layer(core.Layer): ...@@ -665,7 +670,7 @@ class Layer(core.Layer):
Parameters: Parameters:
prefix(str, optional): Prefix to prepend to all parameter names. Default: ''. prefix(str, optional): Prefix to prepend to all parameter names. Default: ''.
include_self(bool, optional): Whether include the Layer itself. Default: False. include_self(bool, optional): Whether include the Layer itself. Default: False.
layers_set(set, optioanl): The set to record duplicate sublayers. Default: None. layers_set(set, optional): The set to record duplicate sublayers. Default: None.
Yields: Yields:
(string, Layer): Tuple of name and Layer (string, Layer): Tuple of name and Layer
...@@ -1028,6 +1033,54 @@ class Layer(core.Layer): ...@@ -1028,6 +1033,54 @@ class Layer(core.Layer):
self._parameters[name] = parameter self._parameters[name] = parameter
return parameter return parameter
def _set_op_attrs(self, attrs):
"""
Add customized attribute while append_op. In case of quantization, we want to save
some attributes into op_desc while exporting inference model by @to_static.
Arguments:
attrs(dict): customized attributes that will be added into op_descs.
NOTE: The interface is only exposed to developers.
"""
def is_already_registered(is_pre_hook):
layers_hooks = self._forward_pre_hooks if is_pre_hook else self._forward_post_hooks
candidate_hook = record_program_ops_pre_hook if is_pre_hook else set_op_customized_attrs_post_hook
already_registed = False
if layers_hooks:
last_key = next(reversed(layers_hooks))
already_registed = (layers_hooks[last_key] == candidate_hook)
return already_registed
if not isinstance(attrs, dict):
raise TypeError("attrs should be type(dict), but received {}".
format(type(attrs).__name__))
# NOTE: Overwrite behavior for same key.
self._customized_attrs.update(attrs)
if not is_already_registered(is_pre_hook=True):
pre_hook_helper = self.register_forward_pre_hook(
record_program_ops_pre_hook)
assert len(self._op_recorder.hooks) == 0
self._op_recorder.hooks = [pre_hook_helper]
# manually register post_hook to ensure it is inserted into the head.
if not is_already_registered(is_pre_hook=False):
post_hook_helper = self.register_forward_post_hook(
set_op_customized_attrs_post_hook)
if len(self._forward_post_hooks) > 1:
self._forward_post_hooks.move_to_end(
post_hook_helper._hook_id, last=False)
assert len(self._op_recorder.hooks) == 1
# hooks that need to be removed once we finish executing them.
self._op_recorder.hooks.append(post_hook_helper)
def __getstate__(self): def __getstate__(self):
return self.__dict__ return self.__dict__
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
import paddle
import unittest
import numpy as np
from paddle.static import InputSpec
class MySub(paddle.nn.Layer):
def __init__(self):
super(MySub, self).__init__()
def forward(self, x, y, name=None):
return paddle.subtract(x, y, name)
class NetWithOpAttr(paddle.nn.Layer):
def __init__(self, in_num, out_num):
super(NetWithOpAttr, self).__init__()
self.linear = paddle.nn.Linear(in_num, out_num)
self.bn = paddle.nn.BatchNorm(out_num)
self.sub = MySub()
def forward(self, x):
out = self.linear(x)
out = self.sub(out, x)
out = self.bn(out)
return out
@paddle.jit.to_static(input_spec=[InputSpec([10, 16])])
def with_cond(self, x):
if paddle.mean(x) > 0.:
out = self.linear(x)
else:
out = self.sub(x, x)
out = self.bn(out)
return out
class CheckOpAttr(unittest.TestCase):
def setUp(self):
self.in_num = 16
self.out_num = 16
self.x = paddle.randn([10, self.in_num])
self.expected_results()
def expected_results(self):
self.fc_attrs = {
"int_val": 10,
"int_vals": [10, 20],
"float_val": 3.8,
"float_vals": [3.8, -0.2]
}
self.bn_attrs = {"bool_val": True, "bool_vals": [True, False]}
self.sub_attrs = {"int_vals": [10, 20], "bool_vals": [True, False]}
self.infos = {
'matmul': self.fc_attrs,
'elementwise_add': self.fc_attrs,
'batch_norm': self.bn_attrs,
'tanh': self.bn_attrs,
'elementwise_sub': self.sub_attrs
}
def test_set_op_attrs(self):
net = NetWithOpAttr(self.in_num, self.out_num)
# set attrs
net.linear._set_op_attrs(self.fc_attrs)
net.bn._set_op_attrs({"bool_val": False}) # test overwrite behavior
net.bn._set_op_attrs(self.bn_attrs)
net.sub._set_op_attrs(self.sub_attrs)
# assert hooks exist.
self.assertEqual(len(net.linear._forward_pre_hooks), 1)
self.assertEqual(len(net.linear._forward_post_hooks), 1)
# to_static
net = paddle.jit.to_static(
net, input_spec=[InputSpec.from_tensor(self.x)])
# assert attrs have be set.
self.check_op_attrs(net.forward.concrete_program.main_program)
# assert hooks have be clean.
self.assertEqual(len(net.linear._forward_pre_hooks), 0)
self.assertEqual(len(net.linear._forward_post_hooks), 0)
def check_op_attrs(self, main_program):
for cur_block in main_program.blocks:
ops = cur_block.ops
for op in ops:
if op.type not in self.infos: continue
for attr_name, expect_vals in six.iteritems(self.infos[
op.type]):
op_vals = op.desc.attr(attr_name)
if not isinstance(expect_vals, list):
expect_vals = [expect_vals]
op_vals = [op_vals]
for (op_val, expect_val) in zip(op_vals, expect_vals):
if isinstance(op_val, float):
# C++ vs python: 3.799999952316284 ~= 3.8
self.assertAlmostEqual(op_val, expect_val)
else:
self.assertEqual(op_val, expect_val)
def test_set_op_attrs_with_sub_block(self):
net = NetWithOpAttr(self.in_num, self.out_num)
# set attrs
net.linear._set_op_attrs({
"int_vals": [0, 0]
}) # test overwrite behavior
net.linear._set_op_attrs(self.fc_attrs)
net.bn._set_op_attrs(self.bn_attrs)
net.sub._set_op_attrs(self.sub_attrs)
# assert hooks exist.
self.assertEqual(len(net.linear._forward_pre_hooks), 1)
self.assertEqual(len(net.linear._forward_post_hooks), 1)
# assert attrs have be set.
self.check_op_attrs(net.with_cond.concrete_program.main_program)
# assert hooks have be clean.
self.assertEqual(len(net.linear._forward_pre_hooks), 0)
self.assertEqual(len(net.linear._forward_post_hooks), 0)
def test_type_error(self):
net = NetWithOpAttr(self.in_num, self.out_num)
# attrs should be dict
with self.assertRaises(TypeError):
net.linear._set_op_attrs([self.fc_attrs])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册