未验证 提交 9b48cfda 编写于 作者: X xiongkun 提交者: GitHub

[cherry-pick][Dy2Stat]Support non-tensor type in input_spec (#33464) #34378

[Dy2Stat]Support non-tensor type in input_spec
上级 dbc54d2d
......@@ -193,14 +193,8 @@ class FunctionSpec(object):
raise TypeError(
"The type(input_spec) should be one of (tuple, list), but received {}.".
format(type_name(input_spec)))
input_spec = tuple(input_spec)
for spec in flatten(input_spec):
if not isinstance(spec, paddle.static.InputSpec):
raise ValueError(
"The type(elem) from input_spec should be `InputSpec`, but received {}.".
format(type_name(spec)))
return input_spec
return tuple(input_spec)
def __repr__(self):
return "function: {}({}), input_spec: {}".format(
......@@ -326,9 +320,8 @@ def convert_to_input_spec(inputs, input_spec):
elif isinstance(input_spec, paddle.static.InputSpec):
return input_spec
else:
raise TypeError(
"The type(input_spec) should be a `InputSpec` or dict/list/tuple of it, but received {}.".
type_name(input_spec))
# NOTE(Aurelius84): Support non-Tensor type as input spec info
return input_spec
def replace_spec_empty_name(args_name, input_with_spec):
......
......@@ -20,7 +20,6 @@ import inspect
import six
import textwrap
import threading
import warnings
import weakref
from paddle.fluid import framework
......@@ -314,7 +313,7 @@ class StaticFunction(object):
# Here calls `warnings.warn` but not `logging_utils.warn` because by default warnings.warn(message)
# will show up **only once**. StaticFunction.__call__ will run many times, it is appropriate to
# display this warning message only once.
warnings.warn(
logging_utils.warn(
"The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output. If you would like to get static graph output, please call API "
"ProgramTranslator.enable(True)")
......@@ -481,6 +480,10 @@ class StaticFunction(object):
# NOTE(chenweihang): we should always translated program based on the `input_spec`
# decorated on forward if it is valid
desired_input_spec = self._function_spec.input_spec
if input_spec is not None:
logging_utils.warn(
"\n\nYou have specified `input_spec` both in function definition (higher priority) and `paddle.jit.save` (will be ignored.)\n\n\t Using: {}\n\n\t Ignore: {}\n".
format(desired_input_spec, input_spec))
has_input_spec = (desired_input_spec is not None)
if has_input_spec:
......@@ -886,7 +889,7 @@ class ProgramTranslator(object):
if not self.enable_to_static:
# Here calls `warnings.warn` but not `logging_utils.warn` because by default warnings.warn(message)
# will show up **only once**.
warnings.warn(
logging_utils.warn(
"The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output. "
"Please call ProgramTranslator.enable(True) if you would like to get static output."
......
......@@ -27,6 +27,7 @@ import tempfile
import textwrap
import numpy as np
import paddle
from paddle.fluid import unique_name
from paddle.fluid.data_feeder import convert_dtype
......@@ -141,9 +142,9 @@ def make_hashable(x, error_msg=None):
"""
Makes input `x` hashable.
For some unhashable objects, such as `dict/list/np.ndarray`,applying hash function by using their values.
For some unhashable objects, such as `dict/list/set/np.ndarray`,applying hash function by using their values.
"""
if isinstance(x, (tuple, list)):
if isinstance(x, (tuple, list, set)):
return tuple(map(make_hashable, x))
try:
......@@ -1428,10 +1429,10 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
Returns True if the two input specs are compatible, otherwise False.
args:
src_input_spec (list[InputSpec]|tuple(InputSpec)): list/tuple of
paddle.static.InputSpec
desired_input_specs (list[InputSpec]|tuple(InputSpec)): list/tuple of
paddle.static.InputSpec
src_input_spec (list or tuple[InputSpec et.al]): list/tuple of
paddle.static.InputSpec or int/str et.al
desired_input_specs (list or tuple[InputSpec et.al]): list/tuple of
paddle.static.InputSpec or int/str et.al
"""
len_specs = len(src_input_specs)
if len_specs != len(desired_input_specs):
......@@ -1440,11 +1441,29 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
for spec in src_input_specs:
if spec not in desired_input_specs:
return False
else:
for i in range(len_specs):
src_shape = src_input_specs[i].shape
other_shape = desired_input_specs[i].shape
for (src_spec, desired_spec) in zip(src_input_specs,
desired_input_specs):
if isinstance(src_spec, paddle.static.InputSpec) or isinstance(
desired_spec, paddle.static.InputSpec):
if not _compatible_tensor_spec(src_spec, desired_spec):
return False
else:
if not _compatible_non_tensor_spec(src_spec, desired_spec):
return False
return True
def _compatible_tensor_spec(src_spec, desired_spec):
"""
Check whether two tensor type spec is compatible.
"""
for spec in [src_spec, desired_spec]:
if not isinstance(spec, paddle.static.InputSpec):
return False
src_shape = src_spec.shape
other_shape = desired_spec.shape
len_shape = len(src_shape)
if len_shape != len(other_shape):
return False
......@@ -1456,14 +1475,35 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
if src_shape[j] != other_shape[j]:
return False
src_dtype = convert_dtype(src_input_specs[i].dtype)
other_dtype = convert_dtype(desired_input_specs[i].dtype)
src_dtype = convert_dtype(src_spec.dtype)
other_dtype = convert_dtype(desired_spec.dtype)
if src_dtype != other_dtype:
return False
return True
def _compatible_non_tensor_spec(src_spec, desired_spec):
"""
Check whether two non-tensor type spec is compatible.
"""
def hash_value(spec):
try:
hash_val = make_hashable(spec)
except:
hash_val = None
return hash_val
src_hash_val = hash_value(src_spec)
desired_hash_val = hash_value(desired_spec)
if src_hash_val != desired_hash_val:
return False
else:
return True
def slice_is_num(slice_node):
# A slice_node.slice can be a:
# (1) ast.Index, which is a simple number such as [1], [-2]
......
......@@ -403,8 +403,15 @@ def _get_input_var_names(inputs, input_spec):
]
if input_spec is None:
# no prune
result_list = input_var_names
elif input_spec is not None and len(input_spec) == len(input_var_names):
return input_var_names
else:
# fileter out non-tensor type spec infos.
input_spec = [
spec for spec in input_spec
if isinstance(spec, paddle.static.InputSpec)
]
if len(input_spec) == len(input_var_names):
# no prune
result_list = input_var_names
# if input spec name not in input_var_names, only raise warning
......@@ -530,8 +537,9 @@ def save(layer, path, input_spec=None, **configs):
Args:
layer (Layer|function): The Layer or function to be saved.
path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``.
input_spec (list[InputSpec|Tensor]|tuple[InputSpec|Tensor], optional): Describes the input of the saved model's forward
method, which can be described by InputSpec or example Tensor. If None, all input variables of
input_spec (list or tuple[InputSpec|Tensor|Python built-in variable], optional): Describes the input of the saved model's forward
method, which can be described by InputSpec or example Tensor. Moreover, we support to specify non-tensor type argument,
such as int, float, string, or list/dict of them.If None, all input variables of
the original Layer's forward method would be the inputs of the saved model. Default None.
**configs (dict, optional): Other save configuration options for compatibility. We do not
recommend using these configurations, they may be removed in the future. If not necessary,
......@@ -698,9 +706,8 @@ def save(layer, path, input_spec=None, **configs):
inner_input_spec.append(
paddle.static.InputSpec.from_tensor(var))
else:
raise TypeError(
"The element in input_spec list should be 'Variable' or `paddle.static.InputSpec`, but received element's type is %s."
% type(var))
# NOTE(Aurelius84): Support non-Tensor type in `input_spec`.
inner_input_spec.append(var)
# parse configs
configs = _parse_save_configs(configs)
......@@ -719,7 +726,7 @@ def save(layer, path, input_spec=None, **configs):
inner_input_spec)
elif 'forward' == attr_func:
# transform in jit.save, if input_spec is incomplete, declarative will throw error
# inner_input_spec is list[InputSpec], it should be packed with same sturcture
# inner_input_spec is list[InputSpec], it should be packed with same structure
# as original input_spec here.
if inner_input_spec:
inner_input_spec = pack_sequence_as(input_spec,
......
......@@ -39,10 +39,6 @@ class TestFunctionSpec(unittest.TestCase):
with self.assertRaises(TypeError):
foo_spec = FunctionSpec(foo_func, input_spec=a_spec)
# each element of input_spec should be `InputSpec`
with self.assertRaises(ValueError):
foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, 10])
foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec])
self.assertTrue(len(foo_spec.flat_input_spec) == 2)
......
......@@ -14,9 +14,11 @@
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.static import InputSpec
from paddle.fluid.framework import core, convert_np_dtype_to_dtype_
from paddle.fluid.dygraph.dygraph_to_static.utils import _compatible_non_tensor_spec
class TestInputSpec(unittest.TestCase):
......@@ -30,7 +32,7 @@ class TestInputSpec(unittest.TestCase):
x_bool = fluid.layers.fill_constant(shape=[1], dtype='bool', value=True)
bool_spec = InputSpec.from_tensor(x_bool)
self.assertEqual(bool_spec.dtype, x_bool.dtype)
self.assertEqual(bool_spec.shape, x_bool.shape)
self.assertEqual(list(bool_spec.shape), list(x_bool.shape))
self.assertEqual(bool_spec.name, x_bool.name)
bool_spec2 = InputSpec.from_tensor(x_bool, name='bool_spec')
......@@ -109,5 +111,211 @@ class TestInputSpec(unittest.TestCase):
self.assertTrue(hash(tensor_spec_3) != hash(tensor_spec_4))
class NetWithNonTensorSpec(paddle.nn.Layer):
def __init__(self, in_num, out_num):
super(NetWithNonTensorSpec, self).__init__()
self.linear_1 = paddle.nn.Linear(in_num, out_num)
self.bn_1 = paddle.nn.BatchNorm1D(out_num)
self.linear_2 = paddle.nn.Linear(in_num, out_num)
self.bn_2 = paddle.nn.BatchNorm1D(out_num)
self.linear_3 = paddle.nn.Linear(in_num, out_num)
self.bn_3 = paddle.nn.BatchNorm1D(out_num)
def forward(self, x, bool_v=False, str_v="bn", int_v=1, list_v=None):
x = self.linear_1(x)
if 'bn' in str_v:
x = self.bn_1(x)
if bool_v:
x = self.linear_2(x)
x = self.bn_2(x)
config = {"int_v": int_v, 'other_key': "value"}
if list_v and list_v[-1] > 2:
x = self.linear_3(x)
x = self.another_func(x, config)
out = paddle.mean(x)
return out
def another_func(self, x, config=None):
# config is a dict actually
use_bn = config['int_v'] > 0
x = self.linear_1(x)
if use_bn:
x = self.bn_3(x)
return x
class TestNetWithNonTensorSpec(unittest.TestCase):
def setUp(self):
self.in_num = 16
self.out_num = 16
self.x_spec = paddle.static.InputSpec([-1, 16], name='x')
self.x = paddle.randn([4, 16])
@classmethod
def setUpClass(cls):
paddle.disable_static()
def test_non_tensor_bool(self):
specs = [self.x_spec, False]
self.check_result(specs, 'bool')
def test_non_tensor_str(self):
specs = [self.x_spec, True, "xxx"]
self.check_result(specs, 'str')
def test_non_tensor_int(self):
specs = [self.x_spec, True, "bn", 10]
self.check_result(specs, 'int')
def test_non_tensor_list(self):
specs = [self.x_spec, False, "bn", -10, [4]]
self.check_result(specs, 'list')
def check_result(self, specs, path):
path = './net_non_tensor_' + path
net = NetWithNonTensorSpec(self.in_num, self.out_num)
net.eval()
# dygraph out
dy_out = net(self.x, *specs[1:])
# jit.save directly
paddle.jit.save(net, path + '_direct', input_spec=specs)
load_net = paddle.jit.load(path + '_direct')
load_net.eval()
pred_out = load_net(self.x)
self.assertTrue(np.allclose(dy_out, pred_out))
# @to_static by InputSpec
net = paddle.jit.to_static(net, input_spec=specs)
st_out = net(self.x, *specs[1:])
self.assertTrue(np.allclose(dy_out, st_out))
# jit.save and jit.load
paddle.jit.save(net, path)
load_net = paddle.jit.load(path)
load_net.eval()
load_out = load_net(self.x)
self.assertTrue(np.allclose(st_out, load_out))
def test_spec_compatible(self):
net = NetWithNonTensorSpec(self.in_num, self.out_num)
specs = [self.x_spec, False, "bn", -10]
net = paddle.jit.to_static(net, input_spec=specs)
net.eval()
path = './net_twice'
# NOTE: check input_specs_compatible
new_specs = [self.x_spec, True, "bn", 10]
with self.assertRaises(ValueError):
paddle.jit.save(net, path, input_spec=new_specs)
dy_out = net(self.x)
paddle.jit.save(net, path, [self.x_spec, False, "bn"])
load_net = paddle.jit.load(path)
load_net.eval()
pred_out = load_net(self.x)
self.assertTrue(np.allclose(dy_out, pred_out))
class NetWithNonTensorSpecPrune(paddle.nn.Layer):
def __init__(self, in_num, out_num):
super(NetWithNonTensorSpecPrune, self).__init__()
self.linear_1 = paddle.nn.Linear(in_num, out_num)
self.bn_1 = paddle.nn.BatchNorm1D(out_num)
def forward(self, x, y, use_bn=False):
x = self.linear_1(x)
if use_bn:
x = self.bn_1(x)
out = paddle.mean(x)
if y is not None:
loss = paddle.mean(y) + out
return out, loss
class TestNetWithNonTensorSpecWithPrune(unittest.TestCase):
def setUp(self):
self.in_num = 16
self.out_num = 16
self.x_spec = paddle.static.InputSpec([-1, 16], name='x')
self.y_spec = paddle.static.InputSpec([16], name='y')
self.x = paddle.randn([4, 16])
self.y = paddle.randn([16])
@classmethod
def setUpClass(cls):
paddle.disable_static()
def test_non_tensor_with_prune(self):
specs = [self.x_spec, self.y_spec, True]
path = './net_non_tensor_prune_'
net = NetWithNonTensorSpecPrune(self.in_num, self.out_num)
net.eval()
# dygraph out
dy_out, _ = net(self.x, self.y, *specs[2:])
# jit.save directly
paddle.jit.save(net, path + '_direct', input_spec=specs)
load_net = paddle.jit.load(path + '_direct')
load_net.eval()
pred_out, _ = load_net(self.x, self.y)
self.assertTrue(np.allclose(dy_out, pred_out))
# @to_static by InputSpec
net = paddle.jit.to_static(net, input_spec=specs)
st_out, _ = net(self.x, self.y, *specs[2:])
self.assertTrue(np.allclose(dy_out, st_out))
# jit.save and jit.load with prune y and loss
prune_specs = [self.x_spec, True]
paddle.jit.save(net, path, prune_specs, output_spec=[st_out])
load_net = paddle.jit.load(path)
load_net.eval()
load_out = load_net(self.x) # no y and no loss
self.assertTrue(np.allclose(st_out, load_out))
class UnHashableObject:
def __init__(self, val):
self.val = val
def __hash__(self):
raise TypeError("Unsupported to call hash()")
class TestCompatibleNonTensorSpec(unittest.TestCase):
def test_case(self):
self.assertTrue(_compatible_non_tensor_spec([1, 2, 3], [1, 2, 3]))
self.assertFalse(_compatible_non_tensor_spec([1, 2, 3], [1, 2]))
self.assertFalse(_compatible_non_tensor_spec([1, 2, 3], [1, 3, 2]))
# not supported unhashable object.
self.assertTrue(
_compatible_non_tensor_spec(
UnHashableObject(1), UnHashableObject(1)))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册