diff --git a/python/paddle/fluid/dygraph/layer_hooks.py b/python/paddle/fluid/dygraph/layer_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c6867cb7c8ba720637cc98bafb2e016fadbaa3 --- /dev/null +++ b/python/paddle/fluid/dygraph/layer_hooks.py @@ -0,0 +1,74 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import six +import warnings + +from paddle.fluid.framework import default_main_program, in_dygraph_mode + + +class LayerOpsRecoder: + """ + Record generated operators information in nn.Layer. + """ + + def __init__(self, start=-1, end=-1, ops=None, is_valid=False, hooks=None): + self.start = start + self.end = end + self.ops = ops + self.is_valid = is_valid + self.hooks = hooks + + +def record_program_ops_pre_hook(layer, inputs): + """ + A pre-hook to mark op numbers before enter layer.forward. + """ + if not in_dygraph_mode(): + if layer._op_recorder.start < 0: + layer._op_recorder.start = len(default_main_program().current_block( + ).ops) + layer._op_recorder.is_valid = True + else: + layer._op_recorder.is_valid = False + warnings.warn( + "{} has recorded the op information before. Please check whether you call this layer twice.". + format(layer._full_name)) + + return None + + +def set_op_customized_attrs_post_hook(layer, inputs, outputs): + """ + A post-hook to append customized attributes into all operators generated in current layer. + """ + if not in_dygraph_mode() and layer._op_recorder.is_valid: + + start = layer._op_recorder.start + end = len(default_main_program().current_block().ops) + assert (start >= 0 and end >= start) + ops = default_main_program().current_block().ops[start:end] + + layer._op_recorder.end = end + layer._op_recorder.ops = ops + + for op in ops: + for attr_name, val in six.iteritems(layer._customized_attrs): + op._set_attr(attr_name, val) + + # remove pre-hook and post-hook + for hook_helper in layer._op_recorder.hooks: + hook_helper.remove() + + return None diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index ecf6be1a0224af6b89033d6e279e3c2cfe3ef192..cb7666b353db793ce45ccd23b51cb993313f820b 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -30,6 +30,7 @@ from . import parallel_helper from .. import unique_name from paddle.fluid import core from .layer_object_helper import LayerObjectHelper +from .layer_hooks import record_program_ops_pre_hook, set_op_customized_attrs_post_hook, LayerOpsRecoder from .base import program_desc_tracing_guard, param_guard from paddle.fluid import framework from ..param_attr import ParamAttr @@ -113,6 +114,10 @@ class Layer(core.Layer): self._sub_layers = collections.OrderedDict() self._loaddict_holder = collections.OrderedDict() + # Record generated op_descs in this layer + self._op_recorder = LayerOpsRecoder(ops=[], hooks=[]) + self._customized_attrs = {} + self._forward_pre_hooks = collections.OrderedDict() self._forward_post_hooks = collections.OrderedDict() @@ -665,7 +670,7 @@ class Layer(core.Layer): Parameters: prefix(str, optional): Prefix to prepend to all parameter names. Default: ''. include_self(bool, optional): Whether include the Layer itself. Default: False. - layers_set(set, optioanl): The set to record duplicate sublayers. Default: None. + layers_set(set, optional): The set to record duplicate sublayers. Default: None. Yields: (string, Layer): Tuple of name and Layer @@ -1028,6 +1033,54 @@ class Layer(core.Layer): self._parameters[name] = parameter return parameter + def _set_op_attrs(self, attrs): + """ + Add customized attribute while append_op. In case of quantization, we want to save + some attributes into op_desc while exporting inference model by @to_static. + + Arguments: + attrs(dict): customized attributes that will be added into op_descs. + + NOTE: The interface is only exposed to developers. + """ + + def is_already_registered(is_pre_hook): + layers_hooks = self._forward_pre_hooks if is_pre_hook else self._forward_post_hooks + candidate_hook = record_program_ops_pre_hook if is_pre_hook else set_op_customized_attrs_post_hook + + already_registed = False + if layers_hooks: + last_key = next(reversed(layers_hooks)) + already_registed = (layers_hooks[last_key] == candidate_hook) + + return already_registed + + if not isinstance(attrs, dict): + raise TypeError("attrs should be type(dict), but received {}". + format(type(attrs).__name__)) + + # NOTE: Overwrite behavior for same key. + self._customized_attrs.update(attrs) + + if not is_already_registered(is_pre_hook=True): + pre_hook_helper = self.register_forward_pre_hook( + record_program_ops_pre_hook) + assert len(self._op_recorder.hooks) == 0 + self._op_recorder.hooks = [pre_hook_helper] + + # manually register post_hook to ensure it is inserted into the head. + if not is_already_registered(is_pre_hook=False): + post_hook_helper = self.register_forward_post_hook( + set_op_customized_attrs_post_hook) + if len(self._forward_post_hooks) > 1: + self._forward_post_hooks.move_to_end( + post_hook_helper._hook_id, last=False) + + assert len(self._op_recorder.hooks) == 1 + + # hooks that need to be removed once we finish executing them. + self._op_recorder.hooks.append(post_hook_helper) + def __getstate__(self): return self.__dict__ diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_op_attr.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_op_attr.py new file mode 100644 index 0000000000000000000000000000000000000000..a39b5d7cd1a44b0193e10158b8fbe7de87850fde --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_op_attr.py @@ -0,0 +1,148 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import six +import paddle +import unittest +import numpy as np + +from paddle.static import InputSpec + + +class MySub(paddle.nn.Layer): + def __init__(self): + super(MySub, self).__init__() + + def forward(self, x, y, name=None): + return paddle.subtract(x, y, name) + + +class NetWithOpAttr(paddle.nn.Layer): + def __init__(self, in_num, out_num): + super(NetWithOpAttr, self).__init__() + + self.linear = paddle.nn.Linear(in_num, out_num) + self.bn = paddle.nn.BatchNorm(out_num) + self.sub = MySub() + + def forward(self, x): + out = self.linear(x) + out = self.sub(out, x) + out = self.bn(out) + return out + + @paddle.jit.to_static(input_spec=[InputSpec([10, 16])]) + def with_cond(self, x): + if paddle.mean(x) > 0.: + out = self.linear(x) + else: + out = self.sub(x, x) + out = self.bn(out) + return out + + +class CheckOpAttr(unittest.TestCase): + def setUp(self): + self.in_num = 16 + self.out_num = 16 + self.x = paddle.randn([10, self.in_num]) + self.expected_results() + + def expected_results(self): + self.fc_attrs = { + "int_val": 10, + "int_vals": [10, 20], + "float_val": 3.8, + "float_vals": [3.8, -0.2] + } + self.bn_attrs = {"bool_val": True, "bool_vals": [True, False]} + self.sub_attrs = {"int_vals": [10, 20], "bool_vals": [True, False]} + + self.infos = { + 'matmul': self.fc_attrs, + 'elementwise_add': self.fc_attrs, + 'batch_norm': self.bn_attrs, + 'tanh': self.bn_attrs, + 'elementwise_sub': self.sub_attrs + } + + def test_set_op_attrs(self): + net = NetWithOpAttr(self.in_num, self.out_num) + # set attrs + net.linear._set_op_attrs(self.fc_attrs) + net.bn._set_op_attrs({"bool_val": False}) # test overwrite behavior + net.bn._set_op_attrs(self.bn_attrs) + net.sub._set_op_attrs(self.sub_attrs) + # assert hooks exist. + self.assertEqual(len(net.linear._forward_pre_hooks), 1) + self.assertEqual(len(net.linear._forward_post_hooks), 1) + # to_static + net = paddle.jit.to_static( + net, input_spec=[InputSpec.from_tensor(self.x)]) + + # assert attrs have be set. + self.check_op_attrs(net.forward.concrete_program.main_program) + + # assert hooks have be clean. + self.assertEqual(len(net.linear._forward_pre_hooks), 0) + self.assertEqual(len(net.linear._forward_post_hooks), 0) + + def check_op_attrs(self, main_program): + for cur_block in main_program.blocks: + ops = cur_block.ops + for op in ops: + if op.type not in self.infos: continue + for attr_name, expect_vals in six.iteritems(self.infos[ + op.type]): + op_vals = op.desc.attr(attr_name) + if not isinstance(expect_vals, list): + expect_vals = [expect_vals] + op_vals = [op_vals] + + for (op_val, expect_val) in zip(op_vals, expect_vals): + if isinstance(op_val, float): + # C++ vs python: 3.799999952316284 ~= 3.8 + self.assertAlmostEqual(op_val, expect_val) + else: + self.assertEqual(op_val, expect_val) + + def test_set_op_attrs_with_sub_block(self): + net = NetWithOpAttr(self.in_num, self.out_num) + # set attrs + net.linear._set_op_attrs({ + "int_vals": [0, 0] + }) # test overwrite behavior + net.linear._set_op_attrs(self.fc_attrs) + net.bn._set_op_attrs(self.bn_attrs) + net.sub._set_op_attrs(self.sub_attrs) + # assert hooks exist. + self.assertEqual(len(net.linear._forward_pre_hooks), 1) + self.assertEqual(len(net.linear._forward_post_hooks), 1) + + # assert attrs have be set. + self.check_op_attrs(net.with_cond.concrete_program.main_program) + + # assert hooks have be clean. + self.assertEqual(len(net.linear._forward_pre_hooks), 0) + self.assertEqual(len(net.linear._forward_post_hooks), 0) + + def test_type_error(self): + net = NetWithOpAttr(self.in_num, self.out_num) + # attrs should be dict + with self.assertRaises(TypeError): + net.linear._set_op_attrs([self.fc_attrs]) + + +if __name__ == '__main__': + unittest.main()