未验证 提交 63b03cf5 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Support non-tensor type in `input_spec` (#33464)

* support non-tensor type

* fix unittest failed

* add unittest with prune

* rm unused code

* coverage

* fix two or
上级 9d6c8bdf
...@@ -193,14 +193,8 @@ class FunctionSpec(object): ...@@ -193,14 +193,8 @@ class FunctionSpec(object):
raise TypeError( raise TypeError(
"The type(input_spec) should be one of (tuple, list), but received {}.". "The type(input_spec) should be one of (tuple, list), but received {}.".
format(type_name(input_spec))) format(type_name(input_spec)))
input_spec = tuple(input_spec)
for spec in flatten(input_spec):
if not isinstance(spec, paddle.static.InputSpec):
raise ValueError(
"The type(elem) from input_spec should be `InputSpec`, but received {}.".
format(type_name(spec)))
return input_spec return tuple(input_spec)
def __repr__(self): def __repr__(self):
return "function: {}({}), input_spec: {}".format( return "function: {}({}), input_spec: {}".format(
...@@ -326,9 +320,8 @@ def convert_to_input_spec(inputs, input_spec): ...@@ -326,9 +320,8 @@ def convert_to_input_spec(inputs, input_spec):
elif isinstance(input_spec, paddle.static.InputSpec): elif isinstance(input_spec, paddle.static.InputSpec):
return input_spec return input_spec
else: else:
raise TypeError( # NOTE(Aurelius84): Support non-Tensor type as input spec info
"The type(input_spec) should be a `InputSpec` or dict/list/tuple of it, but received {}.". return input_spec
type_name(input_spec))
def replace_spec_empty_name(args_name, input_with_spec): def replace_spec_empty_name(args_name, input_with_spec):
......
...@@ -20,7 +20,6 @@ import inspect ...@@ -20,7 +20,6 @@ import inspect
import six import six
import textwrap import textwrap
import threading import threading
import warnings
import weakref import weakref
from paddle.fluid import framework from paddle.fluid import framework
...@@ -314,7 +313,7 @@ class StaticFunction(object): ...@@ -314,7 +313,7 @@ class StaticFunction(object):
# Here calls `warnings.warn` but not `logging_utils.warn` because by default warnings.warn(message) # Here calls `warnings.warn` but not `logging_utils.warn` because by default warnings.warn(message)
# will show up **only once**. StaticFunction.__call__ will run many times, it is appropriate to # will show up **only once**. StaticFunction.__call__ will run many times, it is appropriate to
# display this warning message only once. # display this warning message only once.
warnings.warn( logging_utils.warn(
"The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable to False. " "The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output. If you would like to get static graph output, please call API " "We will just return dygraph output. If you would like to get static graph output, please call API "
"ProgramTranslator.enable(True)") "ProgramTranslator.enable(True)")
...@@ -481,6 +480,10 @@ class StaticFunction(object): ...@@ -481,6 +480,10 @@ class StaticFunction(object):
# NOTE(chenweihang): we should always translated program based on the `input_spec` # NOTE(chenweihang): we should always translated program based on the `input_spec`
# decorated on forward if it is valid # decorated on forward if it is valid
desired_input_spec = self._function_spec.input_spec desired_input_spec = self._function_spec.input_spec
if input_spec is not None:
logging_utils.warn(
"\n\nYou have specified `input_spec` both in function definition (higher priority) and `paddle.jit.save` (will be ignored.)\n\n\t Using: {}\n\n\t Ignore: {}\n".
format(desired_input_spec, input_spec))
has_input_spec = (desired_input_spec is not None) has_input_spec = (desired_input_spec is not None)
if has_input_spec: if has_input_spec:
...@@ -886,7 +889,7 @@ class ProgramTranslator(object): ...@@ -886,7 +889,7 @@ class ProgramTranslator(object):
if not self.enable_to_static: if not self.enable_to_static:
# Here calls `warnings.warn` but not `logging_utils.warn` because by default warnings.warn(message) # Here calls `warnings.warn` but not `logging_utils.warn` because by default warnings.warn(message)
# will show up **only once**. # will show up **only once**.
warnings.warn( logging_utils.warn(
"The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. " "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output. " "We will just return dygraph output. "
"Please call ProgramTranslator.enable(True) if you would like to get static output." "Please call ProgramTranslator.enable(True) if you would like to get static output."
......
...@@ -27,6 +27,7 @@ import tempfile ...@@ -27,6 +27,7 @@ import tempfile
import textwrap import textwrap
import numpy as np import numpy as np
import paddle
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.data_feeder import convert_dtype from paddle.fluid.data_feeder import convert_dtype
...@@ -141,9 +142,9 @@ def make_hashable(x, error_msg=None): ...@@ -141,9 +142,9 @@ def make_hashable(x, error_msg=None):
""" """
Makes input `x` hashable. Makes input `x` hashable.
For some unhashable objects, such as `dict/list/np.ndarray`,applying hash function by using their values. For some unhashable objects, such as `dict/list/set/np.ndarray`,applying hash function by using their values.
""" """
if isinstance(x, (tuple, list)): if isinstance(x, (tuple, list, set)):
return tuple(map(make_hashable, x)) return tuple(map(make_hashable, x))
try: try:
...@@ -1421,10 +1422,10 @@ def input_specs_compatible(src_input_specs, desired_input_specs): ...@@ -1421,10 +1422,10 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
Returns True if the two input specs are compatible, otherwise False. Returns True if the two input specs are compatible, otherwise False.
args: args:
src_input_spec (list[InputSpec]|tuple(InputSpec)): list/tuple of src_input_spec (list or tuple[InputSpec et.al]): list/tuple of
paddle.static.InputSpec paddle.static.InputSpec or int/str et.al
desired_input_specs (list[InputSpec]|tuple(InputSpec)): list/tuple of desired_input_specs (list or tuple[InputSpec et.al]): list/tuple of
paddle.static.InputSpec paddle.static.InputSpec or int/str et.al
""" """
len_specs = len(src_input_specs) len_specs = len(src_input_specs)
if len_specs != len(desired_input_specs): if len_specs != len(desired_input_specs):
...@@ -1433,11 +1434,29 @@ def input_specs_compatible(src_input_specs, desired_input_specs): ...@@ -1433,11 +1434,29 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
for spec in src_input_specs: for spec in src_input_specs:
if spec not in desired_input_specs: if spec not in desired_input_specs:
return False return False
else: else:
for i in range(len_specs): for (src_spec, desired_spec) in zip(src_input_specs,
src_shape = src_input_specs[i].shape desired_input_specs):
other_shape = desired_input_specs[i].shape if isinstance(src_spec, paddle.static.InputSpec) or isinstance(
desired_spec, paddle.static.InputSpec):
if not _compatible_tensor_spec(src_spec, desired_spec):
return False
else:
if not _compatible_non_tensor_spec(src_spec, desired_spec):
return False
return True
def _compatible_tensor_spec(src_spec, desired_spec):
"""
Check whether two tensor type spec is compatible.
"""
for spec in [src_spec, desired_spec]:
if not isinstance(spec, paddle.static.InputSpec):
return False
src_shape = src_spec.shape
other_shape = desired_spec.shape
len_shape = len(src_shape) len_shape = len(src_shape)
if len_shape != len(other_shape): if len_shape != len(other_shape):
return False return False
...@@ -1449,14 +1468,35 @@ def input_specs_compatible(src_input_specs, desired_input_specs): ...@@ -1449,14 +1468,35 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
if src_shape[j] != other_shape[j]: if src_shape[j] != other_shape[j]:
return False return False
src_dtype = convert_dtype(src_input_specs[i].dtype) src_dtype = convert_dtype(src_spec.dtype)
other_dtype = convert_dtype(desired_input_specs[i].dtype) other_dtype = convert_dtype(desired_spec.dtype)
if src_dtype != other_dtype: if src_dtype != other_dtype:
return False return False
return True return True
def _compatible_non_tensor_spec(src_spec, desired_spec):
"""
Check whether two non-tensor type spec is compatible.
"""
def hash_value(spec):
try:
hash_val = make_hashable(spec)
except:
hash_val = None
return hash_val
src_hash_val = hash_value(src_spec)
desired_hash_val = hash_value(desired_spec)
if src_hash_val != desired_hash_val:
return False
else:
return True
def slice_is_num(slice_node): def slice_is_num(slice_node):
# A slice_node.slice can be a: # A slice_node.slice can be a:
# (1) ast.Index, which is a simple number such as [1], [-2] # (1) ast.Index, which is a simple number such as [1], [-2]
......
...@@ -403,8 +403,15 @@ def _get_input_var_names(inputs, input_spec): ...@@ -403,8 +403,15 @@ def _get_input_var_names(inputs, input_spec):
] ]
if input_spec is None: if input_spec is None:
# no prune # no prune
result_list = input_var_names return input_var_names
elif input_spec is not None and len(input_spec) == len(input_var_names): else:
# fileter out non-tensor type spec infos.
input_spec = [
spec for spec in input_spec
if isinstance(spec, paddle.static.InputSpec)
]
if len(input_spec) == len(input_var_names):
# no prune # no prune
result_list = input_var_names result_list = input_var_names
# if input spec name not in input_var_names, only raise warning # if input spec name not in input_var_names, only raise warning
...@@ -530,8 +537,9 @@ def save(layer, path, input_spec=None, **configs): ...@@ -530,8 +537,9 @@ def save(layer, path, input_spec=None, **configs):
Args: Args:
layer (Layer|function): The Layer or function to be saved. layer (Layer|function): The Layer or function to be saved.
path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``. path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``.
input_spec (list[InputSpec|Tensor]|tuple[InputSpec|Tensor], optional): Describes the input of the saved model's forward input_spec (list or tuple[InputSpec|Tensor|Python built-in variable], optional): Describes the input of the saved model's forward
method, which can be described by InputSpec or example Tensor. If None, all input variables of method, which can be described by InputSpec or example Tensor. Moreover, we support to specify non-tensor type argument,
such as int, float, string, or list/dict of them.If None, all input variables of
the original Layer's forward method would be the inputs of the saved model. Default None. the original Layer's forward method would be the inputs of the saved model. Default None.
**configs (dict, optional): Other save configuration options for compatibility. We do not **configs (dict, optional): Other save configuration options for compatibility. We do not
recommend using these configurations, they may be removed in the future. If not necessary, recommend using these configurations, they may be removed in the future. If not necessary,
...@@ -698,9 +706,8 @@ def save(layer, path, input_spec=None, **configs): ...@@ -698,9 +706,8 @@ def save(layer, path, input_spec=None, **configs):
inner_input_spec.append( inner_input_spec.append(
paddle.static.InputSpec.from_tensor(var)) paddle.static.InputSpec.from_tensor(var))
else: else:
raise TypeError( # NOTE(Aurelius84): Support non-Tensor type in `input_spec`.
"The element in input_spec list should be 'Variable' or `paddle.static.InputSpec`, but received element's type is %s." inner_input_spec.append(var)
% type(var))
# parse configs # parse configs
configs = _parse_save_configs(configs) configs = _parse_save_configs(configs)
...@@ -719,7 +726,7 @@ def save(layer, path, input_spec=None, **configs): ...@@ -719,7 +726,7 @@ def save(layer, path, input_spec=None, **configs):
inner_input_spec) inner_input_spec)
elif 'forward' == attr_func: elif 'forward' == attr_func:
# transform in jit.save, if input_spec is incomplete, declarative will throw error # transform in jit.save, if input_spec is incomplete, declarative will throw error
# inner_input_spec is list[InputSpec], it should be packed with same sturcture # inner_input_spec is list[InputSpec], it should be packed with same structure
# as original input_spec here. # as original input_spec here.
if inner_input_spec: if inner_input_spec:
inner_input_spec = pack_sequence_as(input_spec, inner_input_spec = pack_sequence_as(input_spec,
......
...@@ -39,10 +39,6 @@ class TestFunctionSpec(unittest.TestCase): ...@@ -39,10 +39,6 @@ class TestFunctionSpec(unittest.TestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
foo_spec = FunctionSpec(foo_func, input_spec=a_spec) foo_spec = FunctionSpec(foo_func, input_spec=a_spec)
# each element of input_spec should be `InputSpec`
with self.assertRaises(ValueError):
foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, 10])
foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec]) foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec])
self.assertTrue(len(foo_spec.flat_input_spec) == 2) self.assertTrue(len(foo_spec.flat_input_spec) == 2)
......
...@@ -14,9 +14,11 @@ ...@@ -14,9 +14,11 @@
import unittest import unittest
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.static import InputSpec from paddle.static import InputSpec
from paddle.fluid.framework import core, convert_np_dtype_to_dtype_ from paddle.fluid.framework import core, convert_np_dtype_to_dtype_
from paddle.fluid.dygraph.dygraph_to_static.utils import _compatible_non_tensor_spec
class TestInputSpec(unittest.TestCase): class TestInputSpec(unittest.TestCase):
...@@ -30,7 +32,7 @@ class TestInputSpec(unittest.TestCase): ...@@ -30,7 +32,7 @@ class TestInputSpec(unittest.TestCase):
x_bool = fluid.layers.fill_constant(shape=[1], dtype='bool', value=True) x_bool = fluid.layers.fill_constant(shape=[1], dtype='bool', value=True)
bool_spec = InputSpec.from_tensor(x_bool) bool_spec = InputSpec.from_tensor(x_bool)
self.assertEqual(bool_spec.dtype, x_bool.dtype) self.assertEqual(bool_spec.dtype, x_bool.dtype)
self.assertEqual(bool_spec.shape, x_bool.shape) self.assertEqual(list(bool_spec.shape), list(x_bool.shape))
self.assertEqual(bool_spec.name, x_bool.name) self.assertEqual(bool_spec.name, x_bool.name)
bool_spec2 = InputSpec.from_tensor(x_bool, name='bool_spec') bool_spec2 = InputSpec.from_tensor(x_bool, name='bool_spec')
...@@ -109,5 +111,211 @@ class TestInputSpec(unittest.TestCase): ...@@ -109,5 +111,211 @@ class TestInputSpec(unittest.TestCase):
self.assertTrue(hash(tensor_spec_3) != hash(tensor_spec_4)) self.assertTrue(hash(tensor_spec_3) != hash(tensor_spec_4))
class NetWithNonTensorSpec(paddle.nn.Layer):
def __init__(self, in_num, out_num):
super(NetWithNonTensorSpec, self).__init__()
self.linear_1 = paddle.nn.Linear(in_num, out_num)
self.bn_1 = paddle.nn.BatchNorm1D(out_num)
self.linear_2 = paddle.nn.Linear(in_num, out_num)
self.bn_2 = paddle.nn.BatchNorm1D(out_num)
self.linear_3 = paddle.nn.Linear(in_num, out_num)
self.bn_3 = paddle.nn.BatchNorm1D(out_num)
def forward(self, x, bool_v=False, str_v="bn", int_v=1, list_v=None):
x = self.linear_1(x)
if 'bn' in str_v:
x = self.bn_1(x)
if bool_v:
x = self.linear_2(x)
x = self.bn_2(x)
config = {"int_v": int_v, 'other_key': "value"}
if list_v and list_v[-1] > 2:
x = self.linear_3(x)
x = self.another_func(x, config)
out = paddle.mean(x)
return out
def another_func(self, x, config=None):
# config is a dict actually
use_bn = config['int_v'] > 0
x = self.linear_1(x)
if use_bn:
x = self.bn_3(x)
return x
class TestNetWithNonTensorSpec(unittest.TestCase):
def setUp(self):
self.in_num = 16
self.out_num = 16
self.x_spec = paddle.static.InputSpec([-1, 16], name='x')
self.x = paddle.randn([4, 16])
@classmethod
def setUpClass(cls):
paddle.disable_static()
def test_non_tensor_bool(self):
specs = [self.x_spec, False]
self.check_result(specs, 'bool')
def test_non_tensor_str(self):
specs = [self.x_spec, True, "xxx"]
self.check_result(specs, 'str')
def test_non_tensor_int(self):
specs = [self.x_spec, True, "bn", 10]
self.check_result(specs, 'int')
def test_non_tensor_list(self):
specs = [self.x_spec, False, "bn", -10, [4]]
self.check_result(specs, 'list')
def check_result(self, specs, path):
path = './net_non_tensor_' + path
net = NetWithNonTensorSpec(self.in_num, self.out_num)
net.eval()
# dygraph out
dy_out = net(self.x, *specs[1:])
# jit.save directly
paddle.jit.save(net, path + '_direct', input_spec=specs)
load_net = paddle.jit.load(path + '_direct')
load_net.eval()
pred_out = load_net(self.x)
self.assertTrue(np.allclose(dy_out, pred_out))
# @to_static by InputSpec
net = paddle.jit.to_static(net, input_spec=specs)
st_out = net(self.x, *specs[1:])
self.assertTrue(np.allclose(dy_out, st_out))
# jit.save and jit.load
paddle.jit.save(net, path)
load_net = paddle.jit.load(path)
load_net.eval()
load_out = load_net(self.x)
self.assertTrue(np.allclose(st_out, load_out))
def test_spec_compatible(self):
net = NetWithNonTensorSpec(self.in_num, self.out_num)
specs = [self.x_spec, False, "bn", -10]
net = paddle.jit.to_static(net, input_spec=specs)
net.eval()
path = './net_twice'
# NOTE: check input_specs_compatible
new_specs = [self.x_spec, True, "bn", 10]
with self.assertRaises(ValueError):
paddle.jit.save(net, path, input_spec=new_specs)
dy_out = net(self.x)
paddle.jit.save(net, path, [self.x_spec, False, "bn"])
load_net = paddle.jit.load(path)
load_net.eval()
pred_out = load_net(self.x)
self.assertTrue(np.allclose(dy_out, pred_out))
class NetWithNonTensorSpecPrune(paddle.nn.Layer):
def __init__(self, in_num, out_num):
super(NetWithNonTensorSpecPrune, self).__init__()
self.linear_1 = paddle.nn.Linear(in_num, out_num)
self.bn_1 = paddle.nn.BatchNorm1D(out_num)
def forward(self, x, y, use_bn=False):
x = self.linear_1(x)
if use_bn:
x = self.bn_1(x)
out = paddle.mean(x)
if y is not None:
loss = paddle.mean(y) + out
return out, loss
class TestNetWithNonTensorSpecWithPrune(unittest.TestCase):
def setUp(self):
self.in_num = 16
self.out_num = 16
self.x_spec = paddle.static.InputSpec([-1, 16], name='x')
self.y_spec = paddle.static.InputSpec([16], name='y')
self.x = paddle.randn([4, 16])
self.y = paddle.randn([16])
@classmethod
def setUpClass(cls):
paddle.disable_static()
def test_non_tensor_with_prune(self):
specs = [self.x_spec, self.y_spec, True]
path = './net_non_tensor_prune_'
net = NetWithNonTensorSpecPrune(self.in_num, self.out_num)
net.eval()
# dygraph out
dy_out, _ = net(self.x, self.y, *specs[2:])
# jit.save directly
paddle.jit.save(net, path + '_direct', input_spec=specs)
load_net = paddle.jit.load(path + '_direct')
load_net.eval()
pred_out, _ = load_net(self.x, self.y)
self.assertTrue(np.allclose(dy_out, pred_out))
# @to_static by InputSpec
net = paddle.jit.to_static(net, input_spec=specs)
st_out, _ = net(self.x, self.y, *specs[2:])
self.assertTrue(np.allclose(dy_out, st_out))
# jit.save and jit.load with prune y and loss
prune_specs = [self.x_spec, True]
paddle.jit.save(net, path, prune_specs, output_spec=[st_out])
load_net = paddle.jit.load(path)
load_net.eval()
load_out = load_net(self.x) # no y and no loss
self.assertTrue(np.allclose(st_out, load_out))
class UnHashableObject:
def __init__(self, val):
self.val = val
def __hash__(self):
raise TypeError("Unsupported to call hash()")
class TestCompatibleNonTensorSpec(unittest.TestCase):
def test_case(self):
self.assertTrue(_compatible_non_tensor_spec([1, 2, 3], [1, 2, 3]))
self.assertFalse(_compatible_non_tensor_spec([1, 2, 3], [1, 2]))
self.assertFalse(_compatible_non_tensor_spec([1, 2, 3], [1, 3, 2]))
# not supported unhashable object.
self.assertTrue(
_compatible_non_tensor_spec(
UnHashableObject(1), UnHashableObject(1)))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册