diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py index 205766e4613423f41c542ba3c48f2d5d1db0fb02..031351ca118ef5185b82db754915c80c2f069de0 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py @@ -193,14 +193,8 @@ class FunctionSpec(object): raise TypeError( "The type(input_spec) should be one of (tuple, list), but received {}.". format(type_name(input_spec))) - input_spec = tuple(input_spec) - for spec in flatten(input_spec): - if not isinstance(spec, paddle.static.InputSpec): - raise ValueError( - "The type(elem) from input_spec should be `InputSpec`, but received {}.". - format(type_name(spec))) - return input_spec + return tuple(input_spec) def __repr__(self): return "function: {}({}), input_spec: {}".format( @@ -326,9 +320,8 @@ def convert_to_input_spec(inputs, input_spec): elif isinstance(input_spec, paddle.static.InputSpec): return input_spec else: - raise TypeError( - "The type(input_spec) should be a `InputSpec` or dict/list/tuple of it, but received {}.". - type_name(input_spec)) + # NOTE(Aurelius84): Support non-Tensor type as input spec info + return input_spec def replace_spec_empty_name(args_name, input_with_spec): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 770a72fbaf004865788fe949607dea6faa7a7930..4532c65e74bd21dde769304105ed32ba305d2a47 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -20,7 +20,6 @@ import inspect import six import textwrap import threading -import warnings import weakref from paddle.fluid import framework @@ -314,7 +313,7 @@ class StaticFunction(object): # Here calls `warnings.warn` but not `logging_utils.warn` because by default warnings.warn(message) # will show up **only once**. StaticFunction.__call__ will run many times, it is appropriate to # display this warning message only once. - warnings.warn( + logging_utils.warn( "The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable to False. " "We will just return dygraph output. If you would like to get static graph output, please call API " "ProgramTranslator.enable(True)") @@ -481,6 +480,10 @@ class StaticFunction(object): # NOTE(chenweihang): we should always translated program based on the `input_spec` # decorated on forward if it is valid desired_input_spec = self._function_spec.input_spec + if input_spec is not None: + logging_utils.warn( + "\n\nYou have specified `input_spec` both in function definition (higher priority) and `paddle.jit.save` (will be ignored.)\n\n\t Using: {}\n\n\t Ignore: {}\n". + format(desired_input_spec, input_spec)) has_input_spec = (desired_input_spec is not None) if has_input_spec: @@ -886,7 +889,7 @@ class ProgramTranslator(object): if not self.enable_to_static: # Here calls `warnings.warn` but not `logging_utils.warn` because by default warnings.warn(message) # will show up **only once**. - warnings.warn( + logging_utils.warn( "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. " "We will just return dygraph output. " "Please call ProgramTranslator.enable(True) if you would like to get static output." diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 001116a74c9cc5f149de8ab1ebd7f8f5c2f68068..f27501d1c35a347bac807aee8c21de43ce7e79e7 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -27,6 +27,7 @@ import tempfile import textwrap import numpy as np +import paddle from paddle.fluid import unique_name from paddle.fluid.data_feeder import convert_dtype @@ -141,9 +142,9 @@ def make_hashable(x, error_msg=None): """ Makes input `x` hashable. - For some unhashable objects, such as `dict/list/np.ndarray`,applying hash function by using their values. + For some unhashable objects, such as `dict/list/set/np.ndarray`,applying hash function by using their values. """ - if isinstance(x, (tuple, list)): + if isinstance(x, (tuple, list, set)): return tuple(map(make_hashable, x)) try: @@ -1421,10 +1422,10 @@ def input_specs_compatible(src_input_specs, desired_input_specs): Returns True if the two input specs are compatible, otherwise False. args: - src_input_spec (list[InputSpec]|tuple(InputSpec)): list/tuple of - paddle.static.InputSpec - desired_input_specs (list[InputSpec]|tuple(InputSpec)): list/tuple of - paddle.static.InputSpec + src_input_spec (list or tuple[InputSpec et.al]): list/tuple of + paddle.static.InputSpec or int/str et.al + desired_input_specs (list or tuple[InputSpec et.al]): list/tuple of + paddle.static.InputSpec or int/str et.al """ len_specs = len(src_input_specs) if len_specs != len(desired_input_specs): @@ -1433,30 +1434,69 @@ def input_specs_compatible(src_input_specs, desired_input_specs): for spec in src_input_specs: if spec not in desired_input_specs: return False - else: - for i in range(len_specs): - src_shape = src_input_specs[i].shape - other_shape = desired_input_specs[i].shape - len_shape = len(src_shape) - if len_shape != len(other_shape): - return False - for j in range(len_shape): - if src_shape[j] is None or src_shape[j] < 0: - continue - if other_shape[j] is None or other_shape[j] < 0: - continue - if src_shape[j] != other_shape[j]: + for (src_spec, desired_spec) in zip(src_input_specs, + desired_input_specs): + if isinstance(src_spec, paddle.static.InputSpec) or isinstance( + desired_spec, paddle.static.InputSpec): + if not _compatible_tensor_spec(src_spec, desired_spec): + return False + else: + if not _compatible_non_tensor_spec(src_spec, desired_spec): return False - src_dtype = convert_dtype(src_input_specs[i].dtype) - other_dtype = convert_dtype(desired_input_specs[i].dtype) - if src_dtype != other_dtype: - return False + return True + + +def _compatible_tensor_spec(src_spec, desired_spec): + """ + Check whether two tensor type spec is compatible. + """ + for spec in [src_spec, desired_spec]: + if not isinstance(spec, paddle.static.InputSpec): + return False + src_shape = src_spec.shape + other_shape = desired_spec.shape + len_shape = len(src_shape) + if len_shape != len(other_shape): + return False + for j in range(len_shape): + if src_shape[j] is None or src_shape[j] < 0: + continue + if other_shape[j] is None or other_shape[j] < 0: + continue + if src_shape[j] != other_shape[j]: + return False + + src_dtype = convert_dtype(src_spec.dtype) + other_dtype = convert_dtype(desired_spec.dtype) + if src_dtype != other_dtype: + return False return True +def _compatible_non_tensor_spec(src_spec, desired_spec): + """ + Check whether two non-tensor type spec is compatible. + """ + + def hash_value(spec): + try: + hash_val = make_hashable(spec) + except: + hash_val = None + return hash_val + + src_hash_val = hash_value(src_spec) + desired_hash_val = hash_value(desired_spec) + + if src_hash_val != desired_hash_val: + return False + else: + return True + + def slice_is_num(slice_node): # A slice_node.slice can be a: # (1) ast.Index, which is a simple number such as [1], [-2] diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 352a377fa3adc557213abcc8f4919e0d30cae97a..3401f85a78b074f9e195b23828f56a7b3f848ddc 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -403,8 +403,15 @@ def _get_input_var_names(inputs, input_spec): ] if input_spec is None: # no prune - result_list = input_var_names - elif input_spec is not None and len(input_spec) == len(input_var_names): + return input_var_names + else: + # fileter out non-tensor type spec infos. + input_spec = [ + spec for spec in input_spec + if isinstance(spec, paddle.static.InputSpec) + ] + + if len(input_spec) == len(input_var_names): # no prune result_list = input_var_names # if input spec name not in input_var_names, only raise warning @@ -530,8 +537,9 @@ def save(layer, path, input_spec=None, **configs): Args: layer (Layer|function): The Layer or function to be saved. path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``. - input_spec (list[InputSpec|Tensor]|tuple[InputSpec|Tensor], optional): Describes the input of the saved model's forward - method, which can be described by InputSpec or example Tensor. If None, all input variables of + input_spec (list or tuple[InputSpec|Tensor|Python built-in variable], optional): Describes the input of the saved model's forward + method, which can be described by InputSpec or example Tensor. Moreover, we support to specify non-tensor type argument, + such as int, float, string, or list/dict of them.If None, all input variables of the original Layer's forward method would be the inputs of the saved model. Default None. **configs (dict, optional): Other save configuration options for compatibility. We do not recommend using these configurations, they may be removed in the future. If not necessary, @@ -698,9 +706,8 @@ def save(layer, path, input_spec=None, **configs): inner_input_spec.append( paddle.static.InputSpec.from_tensor(var)) else: - raise TypeError( - "The element in input_spec list should be 'Variable' or `paddle.static.InputSpec`, but received element's type is %s." - % type(var)) + # NOTE(Aurelius84): Support non-Tensor type in `input_spec`. + inner_input_spec.append(var) # parse configs configs = _parse_save_configs(configs) @@ -719,7 +726,7 @@ def save(layer, path, input_spec=None, **configs): inner_input_spec) elif 'forward' == attr_func: # transform in jit.save, if input_spec is incomplete, declarative will throw error - # inner_input_spec is list[InputSpec], it should be packed with same sturcture + # inner_input_spec is list[InputSpec], it should be packed with same structure # as original input_spec here. if inner_input_spec: inner_input_spec = pack_sequence_as(input_spec, diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_function_spec.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_function_spec.py index 9dc8c12f24575b293077829be35a8f9c3605c290..c242bb34626c1ca91770e557d6d68646c09f9618 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_function_spec.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_function_spec.py @@ -39,10 +39,6 @@ class TestFunctionSpec(unittest.TestCase): with self.assertRaises(TypeError): foo_spec = FunctionSpec(foo_func, input_spec=a_spec) - # each element of input_spec should be `InputSpec` - with self.assertRaises(ValueError): - foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, 10]) - foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec]) self.assertTrue(len(foo_spec.flat_input_spec) == 2) diff --git a/python/paddle/fluid/tests/unittests/test_input_spec.py b/python/paddle/fluid/tests/unittests/test_input_spec.py index e329a37488a2cb8234532cd0a9beb7a1a25e72a6..4e0aa4a9bcad7d859a1b3ad9d51b9a530f70dd70 100644 --- a/python/paddle/fluid/tests/unittests/test_input_spec.py +++ b/python/paddle/fluid/tests/unittests/test_input_spec.py @@ -14,9 +14,11 @@ import unittest import numpy as np +import paddle import paddle.fluid as fluid from paddle.static import InputSpec from paddle.fluid.framework import core, convert_np_dtype_to_dtype_ +from paddle.fluid.dygraph.dygraph_to_static.utils import _compatible_non_tensor_spec class TestInputSpec(unittest.TestCase): @@ -30,7 +32,7 @@ class TestInputSpec(unittest.TestCase): x_bool = fluid.layers.fill_constant(shape=[1], dtype='bool', value=True) bool_spec = InputSpec.from_tensor(x_bool) self.assertEqual(bool_spec.dtype, x_bool.dtype) - self.assertEqual(bool_spec.shape, x_bool.shape) + self.assertEqual(list(bool_spec.shape), list(x_bool.shape)) self.assertEqual(bool_spec.name, x_bool.name) bool_spec2 = InputSpec.from_tensor(x_bool, name='bool_spec') @@ -109,5 +111,211 @@ class TestInputSpec(unittest.TestCase): self.assertTrue(hash(tensor_spec_3) != hash(tensor_spec_4)) +class NetWithNonTensorSpec(paddle.nn.Layer): + def __init__(self, in_num, out_num): + super(NetWithNonTensorSpec, self).__init__() + self.linear_1 = paddle.nn.Linear(in_num, out_num) + self.bn_1 = paddle.nn.BatchNorm1D(out_num) + + self.linear_2 = paddle.nn.Linear(in_num, out_num) + self.bn_2 = paddle.nn.BatchNorm1D(out_num) + + self.linear_3 = paddle.nn.Linear(in_num, out_num) + self.bn_3 = paddle.nn.BatchNorm1D(out_num) + + def forward(self, x, bool_v=False, str_v="bn", int_v=1, list_v=None): + x = self.linear_1(x) + if 'bn' in str_v: + x = self.bn_1(x) + + if bool_v: + x = self.linear_2(x) + x = self.bn_2(x) + + config = {"int_v": int_v, 'other_key': "value"} + if list_v and list_v[-1] > 2: + x = self.linear_3(x) + x = self.another_func(x, config) + + out = paddle.mean(x) + return out + + def another_func(self, x, config=None): + # config is a dict actually + use_bn = config['int_v'] > 0 + + x = self.linear_1(x) + if use_bn: + x = self.bn_3(x) + + return x + + +class TestNetWithNonTensorSpec(unittest.TestCase): + def setUp(self): + self.in_num = 16 + self.out_num = 16 + self.x_spec = paddle.static.InputSpec([-1, 16], name='x') + self.x = paddle.randn([4, 16]) + + @classmethod + def setUpClass(cls): + paddle.disable_static() + + def test_non_tensor_bool(self): + specs = [self.x_spec, False] + self.check_result(specs, 'bool') + + def test_non_tensor_str(self): + specs = [self.x_spec, True, "xxx"] + self.check_result(specs, 'str') + + def test_non_tensor_int(self): + specs = [self.x_spec, True, "bn", 10] + self.check_result(specs, 'int') + + def test_non_tensor_list(self): + specs = [self.x_spec, False, "bn", -10, [4]] + self.check_result(specs, 'list') + + def check_result(self, specs, path): + path = './net_non_tensor_' + path + + net = NetWithNonTensorSpec(self.in_num, self.out_num) + net.eval() + # dygraph out + dy_out = net(self.x, *specs[1:]) + + # jit.save directly + paddle.jit.save(net, path + '_direct', input_spec=specs) + load_net = paddle.jit.load(path + '_direct') + load_net.eval() + pred_out = load_net(self.x) + + self.assertTrue(np.allclose(dy_out, pred_out)) + + # @to_static by InputSpec + net = paddle.jit.to_static(net, input_spec=specs) + st_out = net(self.x, *specs[1:]) + + self.assertTrue(np.allclose(dy_out, st_out)) + + # jit.save and jit.load + paddle.jit.save(net, path) + load_net = paddle.jit.load(path) + load_net.eval() + load_out = load_net(self.x) + + self.assertTrue(np.allclose(st_out, load_out)) + + def test_spec_compatible(self): + net = NetWithNonTensorSpec(self.in_num, self.out_num) + + specs = [self.x_spec, False, "bn", -10] + net = paddle.jit.to_static(net, input_spec=specs) + net.eval() + + path = './net_twice' + + # NOTE: check input_specs_compatible + new_specs = [self.x_spec, True, "bn", 10] + with self.assertRaises(ValueError): + paddle.jit.save(net, path, input_spec=new_specs) + + dy_out = net(self.x) + + paddle.jit.save(net, path, [self.x_spec, False, "bn"]) + load_net = paddle.jit.load(path) + load_net.eval() + pred_out = load_net(self.x) + + self.assertTrue(np.allclose(dy_out, pred_out)) + + +class NetWithNonTensorSpecPrune(paddle.nn.Layer): + def __init__(self, in_num, out_num): + super(NetWithNonTensorSpecPrune, self).__init__() + self.linear_1 = paddle.nn.Linear(in_num, out_num) + self.bn_1 = paddle.nn.BatchNorm1D(out_num) + + def forward(self, x, y, use_bn=False): + x = self.linear_1(x) + if use_bn: + x = self.bn_1(x) + + out = paddle.mean(x) + + if y is not None: + loss = paddle.mean(y) + out + + return out, loss + + +class TestNetWithNonTensorSpecWithPrune(unittest.TestCase): + def setUp(self): + self.in_num = 16 + self.out_num = 16 + self.x_spec = paddle.static.InputSpec([-1, 16], name='x') + self.y_spec = paddle.static.InputSpec([16], name='y') + self.x = paddle.randn([4, 16]) + self.y = paddle.randn([16]) + + @classmethod + def setUpClass(cls): + paddle.disable_static() + + def test_non_tensor_with_prune(self): + specs = [self.x_spec, self.y_spec, True] + path = './net_non_tensor_prune_' + + net = NetWithNonTensorSpecPrune(self.in_num, self.out_num) + net.eval() + # dygraph out + dy_out, _ = net(self.x, self.y, *specs[2:]) + + # jit.save directly + paddle.jit.save(net, path + '_direct', input_spec=specs) + load_net = paddle.jit.load(path + '_direct') + load_net.eval() + pred_out, _ = load_net(self.x, self.y) + + self.assertTrue(np.allclose(dy_out, pred_out)) + + # @to_static by InputSpec + net = paddle.jit.to_static(net, input_spec=specs) + st_out, _ = net(self.x, self.y, *specs[2:]) + + self.assertTrue(np.allclose(dy_out, st_out)) + + # jit.save and jit.load with prune y and loss + prune_specs = [self.x_spec, True] + paddle.jit.save(net, path, prune_specs, output_spec=[st_out]) + load_net = paddle.jit.load(path) + load_net.eval() + load_out = load_net(self.x) # no y and no loss + + self.assertTrue(np.allclose(st_out, load_out)) + + +class UnHashableObject: + def __init__(self, val): + self.val = val + + def __hash__(self): + raise TypeError("Unsupported to call hash()") + + +class TestCompatibleNonTensorSpec(unittest.TestCase): + def test_case(self): + self.assertTrue(_compatible_non_tensor_spec([1, 2, 3], [1, 2, 3])) + self.assertFalse(_compatible_non_tensor_spec([1, 2, 3], [1, 2])) + self.assertFalse(_compatible_non_tensor_spec([1, 2, 3], [1, 3, 2])) + + # not supported unhashable object. + self.assertTrue( + _compatible_non_tensor_spec( + UnHashableObject(1), UnHashableObject(1))) + + if __name__ == '__main__': unittest.main()