From f05613683f187eda2ed4542bce8e0767092c84c0 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 27 Aug 2020 18:20:05 +0800 Subject: [PATCH] [Dy2stat] Support InputSpec and Return callable class instance in @declarative (#25960) * add InputSpec * add unittest for tensorSpec and SimpleNet --- .../slim/quantization/imperative/qat.py | 18 +- .../dygraph_to_static/convert_call_func.py | 48 +- .../dygraph_to_static/function_spec.py | 311 +++++++++++ .../dygraph_to_static/program_translator.py | 484 ++++++++++++++---- .../fluid/dygraph/dygraph_to_static/utils.py | 73 +++ python/paddle/fluid/dygraph/io.py | 10 +- python/paddle/fluid/dygraph/jit.py | 101 ++-- .../dygraph_to_static/test_declarative.py | 250 +++++++++ .../dygraph_to_static/test_function_spec.py | 116 +++++ .../dygraph_to_static/test_partial_program.py | 2 +- .../test_save_inference_model.py | 2 +- .../fluid/tests/unittests/test_input_spec.py | 113 ++++ .../tests/unittests/test_jit_save_load.py | 94 +++- python/paddle/incubate/hapi/model.py | 14 +- python/paddle/static/input.py | 209 +++++++- 15 files changed, 1637 insertions(+), 208 deletions(-) create mode 100644 python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_function_spec.py create mode 100644 python/paddle/fluid/tests/unittests/test_input_spec.py diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index cece2ba4a3d..e3755cbafea 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -15,6 +15,7 @@ import logging import numpy as np import sys +import paddle from paddle.fluid import dygraph from paddle.fluid.dygraph.nn import Conv2D from paddle.fluid.dygraph.nn import Linear @@ -195,13 +196,16 @@ class ImperativeQuantAware(object): with dygraph.guard(): model.eval() input_vars = [] - for shape, dtype in zip(input_shape, input_dtype): - raw_data = np.random.random(shape) - input_data = raw_data[np.newaxis, :].astype( - dtype) if append_batch_size else raw_data.astype(dtype) - input_var = dygraph.to_variable(input_data) - input_vars.append(input_var) - outputs = prog_trans.get_output(model.forward, model, *input_vars) + for i, (shape, dtype) in enumerate(zip(input_shape, input_dtype)): + if append_batch_size: + shape = [None] + list(shape) + # Note(Aurelius84): need a elegant way to name this. + in_spec = paddle.static.InputSpec(shape, dtype, 'feed_%d' % i) + input_vars.append(in_spec) + # use `declarative` to convert dygraph into static program + model.forward = dygraph.jit.declarative( + model.forward, input_spec=input_vars) + outputs = model.forward.concrete_program.outputs input_spec = [input_vars[i] for i in feed] configs = dygraph.jit.SaveLoadConfig() configs.separate_params = True diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py index edd7dfcf939..03901dffcd3 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py @@ -27,13 +27,12 @@ import types import numpy import six -from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator +from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticLayer +from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static from paddle.fluid.dygraph.layers import Layer from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func'] -program_translator = ProgramTranslator() -to_static_func = program_translator.get_func def is_builtin(func): @@ -63,7 +62,7 @@ def is_paddle_func(func): def convert_call(func): """ - Converts a function call which needs to be transformed to static fucntion. + Converts a function call which needs to be transformed to static function. Args: func (callable): A callable function or method to convert. @@ -98,6 +97,15 @@ def convert_call(func): func_self = None converted_call = None + # Function in convert_call may be decorated by another `@declarative`, + # in this case, unwraps it into a raw method or function. + if isinstance(func, StaticLayer): + instance = func._class_instance + if instance is not None: + func = func.dygraph_function.__get__(instance) + else: + func = func.dygraph_function + if is_builtin_len(func): return convert_len @@ -109,11 +117,27 @@ def convert_call(func): if func.__name__ == '': return func try: - global_funcs = set([ - fn for fn in func.__globals__.values() if inspect.isfunction(fn) - ]) - if func in global_funcs: - converted_call = to_static_func(func) + # Note(Aurelius84): Because `@declarative` returns a class instance instead of + # a function. This will modify the value referring to itself in `__globals__`. + + # For example: + # + # @declarative + # def foo(x): + # return x + # + # `foo` will be converted into a wrapper class, suppose as `StaticLayer`. + # And `foo.__globals__['foo']` will still return this `StaticLayer` instead of + # `foo` function. So `isinstance(fn, StaticLayer)` is added here. + global_functions = set() + for fn in func.__globals__.values(): + if inspect.isfunction(fn): + global_functions.add(fn) + elif isinstance(fn, StaticLayer): + global_functions.add(fn.dygraph_function) + + if func in global_functions: + converted_call = convert_to_static(func) func_self = getattr(func, '__self__', None) except AttributeError: # NOTE: @@ -127,7 +151,7 @@ def convert_call(func): converted_call = None elif inspect.ismethod(func): try: - converted_call = to_static_func(func) + converted_call = convert_to_static(func) func_self = getattr(func, '__self__', None) except (IOError, OSError): # NOTE: func may have been decorated. @@ -136,7 +160,7 @@ def convert_call(func): elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'): if hasattr(func, 'forward') and isinstance(func, Layer): try: - forward_func = to_static_func(func.forward) + forward_func = convert_to_static(func.forward) setattr(func, 'forward', forward_func) func_self = func except Exception: @@ -146,7 +170,7 @@ def convert_call(func): else: try: call_func = func.__class__.__call__ - converted_call = to_static_func(call_func) + converted_call = convert_to_static(call_func) func_self = func except Exception: # NOTE: diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py new file mode 100644 index 00000000000..5540c63a85b --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py @@ -0,0 +1,311 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import six +import inspect +import numpy as np +import collections +import paddle +from paddle.fluid import core +from paddle.fluid.dygraph import layers +from paddle.fluid.layers.utils import flatten +from paddle.fluid.layers.utils import pack_sequence_as +from paddle.fluid.dygraph.base import switch_to_static_graph +from paddle.fluid.dygraph.dygraph_to_static.utils import parse_arg_and_kwargs +from paddle.fluid.dygraph.dygraph_to_static.utils import type_name +from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code + + +class FunctionSpec(object): + """ + Wrapper class for a function for class method. + """ + + def __init__(self, function, input_spec=None): + self._dygraph_function = function + if input_spec is None: + self._input_spec = None + self._flat_input_spec = None + else: + self._input_spec = self._verify_input_spec(input_spec) + self._flat_input_spec = flatten(self._input_spec) + + # parse full argument names list. + self._arg_names, self._default_kwargs = parse_arg_and_kwargs(function) + + def unified_args_and_kwargs(self, args, kwargs): + """ + Moves kwargs with default value into arguments list to keep `args` contain the same length + value as function definition. + + For example: + + Given function definition: `def foo(x, a=1, b=2)`, + when calling it by `foo(23)`, the args is `[23]`, kwargs is `{a=1, b=2}`. + In this function, it will return args with `[23, 1, 2]`, kwargs with `{}` + + Args: + args(tuple): tuple of input arguments value of decorated function. + kwargs(dict): dict of input keyword arguments value of decorated function. + + Return: + New arguments tuple containing default kwargs value. + """ + if len(self._arg_names) < len(args): + error_msg = "The decorated function `{}` requires {} arguments: {}, but received {} with {}.".format( + self._dygraph_function.__name__, + len(self._arg_names), self._arg_names, len(args), args) + if args and inspect.isclass(args[0]): + error_msg += "\n\tMaybe the function has more than one decorator, we don't support this for now." + raise NotImplementedError(error_msg) + else: + raise ValueError(error_msg) + + args = list(args) + + for i in six.moves.range(len(args), len(self._arg_names)): + arg_name = self._arg_names[i] + if arg_name in kwargs: + args.append(kwargs[arg_name]) + del kwargs[arg_name] + else: + if arg_name not in self._default_kwargs: + raise ValueError( + "`{}()` requires `{}` arguments, but not found in input `args`: {} and `kwargs`: {}.". + format(self._dygraph_function.__name__, arg_name, args, + kwargs)) + args.append(self._default_kwargs[arg_name]) + + return tuple(args), kwargs + + def args_to_input_spec(self, args, kwargs): + """ + Converts input arguments into InputSpec. + + 1. If specific input_spec, use them to construct feed layers. + 2. If input_spec is None, consider all Tensor and Numpy.ndarray as feed layers + + Args: + args(tuple): tuple of input arguments value of function containing default kwargs value. + kwargs(dict): kwargs arguments received by **kwargs. + + Return: + Same nest structure with args by replacing value with InputSpec. + """ + input_with_spec = [] + + if self._input_spec is not None: + # Note: Because the value type and length of `kwargs` is uncertain. + # So we don't support to deal this case while specificing `input_spec` currently. + if kwargs: + raise ValueError( + "{} got unexpected keyword arguments: {}. Cannot trace the function when `input_spec` is specificed.". + format(self._dygraph_function.__name__, kwargs)) + + # Note: The length of `input_spec` can be greater than `args`, + # because `args` may contains non-tensor value merged form `kwargs` + # after `unified_args_and_kwargs`. + if len(args) < len(self._input_spec): + raise ValueError( + "Requires len(arguments) >= len(input_spec), but received len(args):{} < len(InputSpec): {}". + format(len(args), len(self._input_spec))) + + # replace argument with corresponding InputSpec. + input_with_spec = convert_to_input_spec(args, self._input_spec) + else: + for idx, input_var in enumerate(flatten(args)): + if isinstance(input_var, np.ndarray): + input_var = paddle.static.InputSpec.from_numpy(input_var) + elif isinstance(input_var, core.VarBase): + input_var = paddle.static.InputSpec.from_tensor(input_var) + + input_with_spec.append(input_var) + + input_with_spec = pack_sequence_as(args, input_with_spec) + + return input_with_spec + + @switch_to_static_graph + def to_static_inputs_with_spec(self, input_with_spec, main_program): + """ + Constructs feed layer by inputs with InputSpec information for main program. + + Args: + input_with_spec(tuple): input arguments by replacing argument with InputSpec. + main_program(Program): main program for inserting feed layer. + """ + flat_input_spec = flatten(input_with_spec) + + inputs = [] + block = main_program.global_block() + for i, var_spec in enumerate(flat_input_spec): + if isinstance(var_spec, paddle.static.InputSpec): + feed_layer = block.create_var( + # TODO(Aurelius84): consider a more elegant way to name this + name=var_spec.name or "feed_%s" % i, + shape=var_spec.shape, + dtype=var_spec.dtype, + is_data=True, + need_check_feed=False) + else: + feed_layer = var_spec + inputs.append(feed_layer) + + return pack_sequence_as(input_with_spec, inputs) + + def _verify_input_spec(self, input_spec): + """ + Verifies the `input_spec` and its element type is valid. + """ + if not isinstance(input_spec, (tuple, list)): + raise TypeError( + "The type(input_spec) should be one of (tuple, list), but received {}.". + format(type_name(input_spec))) + input_spec = tuple(input_spec) + for spec in flatten(input_spec): + if not isinstance(spec, paddle.static.InputSpec): + raise ValueError( + "The type(elem) from input_spec should be `InputSpec`, but received {}.". + format(type_name(spec))) + + return input_spec + + def __repr__(self): + return "function: {}({}), input_spec: {}".format( + self._dygraph_function.__name__, ','.join(self._arg_names), + self._input_spec) + + @property + def dygraph_function(self): + return self._dygraph_function + + @property + def args_name(self): + return self._arg_names + + @property + def input_spec(self): + return self._input_spec + + @property + def flat_input_spec(self): + return self._flat_input_spec + + @property + def code(self): + return func_to_source_code(self._dygraph_function) + + +def get_parameters(layer_instance, include_sublayer=True): + """ + Returns parameters of decorated layers. If set `include_sublayer` True, + the parameters created in sub layers will be added. + """ + params = collections.OrderedDict() + if layer_instance is not None: + if isinstance(layer_instance, layers.Layer): + if include_sublayer: + params = layer_instance.parameters() + names = [p.name for p in params] + params = collections.OrderedDict(zip(names, params)) + else: + params = layer_instance._parameters + else: + raise TypeError( + "Type of `layer_instance` should be nn.Layer, but received {}". + format(type_name(layer_instance))) + + return params + + +def get_buffers(layer_instance, include_sublayer=True): + """ + Returns Variable buffers of decorated layers. If set `include_sublayer` True, + the Variable buffers created in sub layers will be added. + """ + buffers = collections.OrderedDict() + if layer_instance is not None: + if isinstance(layer_instance, layers.Layer): + if include_sublayer: + buffers = layer_instance.buffers() + names = [buffer.name for buffer in buffers] + buffers = collections.OrderedDict(zip(names, buffers)) + else: + buffers = layer_instance._buffers + else: + raise TypeError( + "Type of `layer_instance` should be nn.Layer, but received {}". + format(type_name(layer_instance))) + return buffers + + +def convert_to_input_spec(inputs, input_spec): + """ + Replaces tensor in structured `inputs` by InputSpec in `input_spec`. + + Args: + inputs(list|dict): nested structure list or dict. + input_spec(list|dict): same nested structure list or dict as inputs. + + + Return: + Same structure with inputs by replacing the element with specified InputSpec. + """ + + def check_type_and_len(input, spec, check_length=False): + if type(input) is not type(spec): + raise TypeError('type(input) should be {}, but received {}.'.format( + type(spec), type(input))) + if check_length and len(input) < len(spec): + raise ValueError( + 'Requires len(inputs) >= len(input_spec), but received len(inputs):{} < len(input_spec):{}'. + format(len(inputs), len(input_spec))) + + if isinstance(input_spec, (tuple, list)): + input_with_spec = [] + check_type_and_len(inputs, input_spec, True) + + for i, spec in enumerate(input_spec): + out_spec = convert_to_input_spec(inputs[i], spec) + input_with_spec.append(out_spec) + + # Note: If the rest inputs contain tensor or numpy.ndarray + # without specific InputSpec, raise warning. + if len(inputs) > len(input_spec): + for rest_input in inputs[len(input_spec):]: + if isinstance(rest_input, (core.VarBase, np.ndarray)): + logging.warning( + "The inputs constain `{}` without specificing InputSpec, its shape and dtype will be treated immutable. " + "Please specific InputSpec information in `@declarative` if you expect them as mutable inputs.". + format(type_name(rest_input))) + input_with_spec.extend(inputs[len(input_spec):]) + + return input_with_spec + elif isinstance(input_spec, dict): + input_with_spec = {} + check_type_and_len(inputs, input_spec, True) + for name, input in inputs.items(): + if name in input_spec: + input_with_spec[name] = convert_to_input_spec(input, + input_spec[name]) + else: + input_with_spec[name] = input + return input_with_spec + elif isinstance(input_spec, paddle.static.InputSpec): + return input_spec + else: + raise TypeError( + "The type(input_spec) should be a `InputSpec` or dict/list/tuple of it, but received {}.". + type_name(input_spec)) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index ceacba25375..d1699695fee 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -13,25 +13,23 @@ # limitations under the License. from __future__ import print_function - +import gast import collections +import logging import inspect +import six import textwrap import threading import warnings import gast -import numpy as np -from paddle.fluid import core -from paddle.fluid import executor from paddle.fluid import framework -from paddle.fluid import scope_guard -from paddle.fluid import unique_name -from paddle.fluid.data_feeder import check_type from paddle.fluid.dygraph import layers +from paddle.fluid.data_feeder import check_type +from paddle.fluid.layers.utils import flatten from paddle.fluid.dygraph.base import param_guard from paddle.fluid.dygraph.base import switch_to_static_graph -from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst +from paddle.fluid.dygraph.dygraph_to_static import DygraphToStaticAst from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info @@ -41,13 +39,20 @@ from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_progr from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code +from paddle.fluid.dygraph.dygraph_to_static.utils import type_name +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func from paddle.fluid.dygraph.dygraph_to_static.utils import unwrap -from paddle.fluid.layers.utils import flatten -from paddle.fluid.layers.utils import pack_sequence_as +from paddle.fluid.dygraph.dygraph_to_static.utils import make_hashable +from paddle.fluid.dygraph.dygraph_to_static.function_spec import FunctionSpec +from paddle.fluid.dygraph.dygraph_to_static.function_spec import get_buffers, get_parameters from paddle.fluid.wrapped_decorator import signature_safe_contextmanager __all__ = ['ProgramTranslator', 'convert_to_static'] +# For each traced function, we set `max_traced_program_count` = 10 to consider caching performance. +# Once exceeding the threshold, we will raise warning to users to make sure the conversion is as expected. +MAX_TRACED_PROGRAM_COUNT = 10 + class FunctionCache(object): """ @@ -136,100 +141,323 @@ def convert_to_static(function): return static_func -class FunctionSpec(object): - def __init__(self, func, args, kwargs): - self._dyfunc = func - self._args = args - self._kwargs = kwargs +class CacheKey(object): + """ + Cached key for ProgramCache. + """ - # TODO(liym27): func has multi layer decorator - dyfunc = getattr(func, '__wrapped__', func) - self._dyfunc_code = inspect.getsource(dyfunc) + __slots__ = ['function_spec', 'input_with_spec', 'class_instance'] - def is_method(self): - return self._args and isinstance(self._args[0], layers.Layer) + def __init__(self, function_spec, input_with_spec, class_instance): + """ + Initializes a cache key. - def parameters(self, include_sublayer=True): + Args: + functions_spec(FunctionSpec): a FunctionSpec instance of decorated function. + input_with_spec(list[InputSpec]): actual inputs with some arguments replaced by InputSpec. + class_instance(object): a instance of class `Layer`. """ - Returns parameters of decorated layers. If set `include_sublayer` True, - the parameters created in sub layers will be added. + self.function_spec = function_spec + self.input_with_spec = input_with_spec + self.class_instance = class_instance + + @classmethod + def from_func_and_args(cls, function_spec, args, kwargs, class_instance): """ - params = collections.OrderedDict() - if self.is_method(): - layer_instance = self._args[0] - if include_sublayer: - params = layer_instance.parameters() - names = [p.name for p in params] - params = collections.OrderedDict(zip(names, params)) + Generated a CacheKey instance by given inputs. + + Args: + functions_spec(FunctionSpec): a FunctionSpec instance of decorated function. + args(tuple): tuple of actual inputs arguments. + kwargs(dict): dict of actual inputs keyword arguments. + class_instance(object): a instance of class `Layer`. + """ + # 1. filter `self` in args + if args and isinstance(args[0], layers.Layer): + args = args[1:] + # 2. convert tensor and numpy array into InputSpec + _args, _kwargs = function_spec.unified_args_and_kwargs(args, kwargs) + input_with_spec = function_spec.args_to_input_spec(_args, _kwargs) + + # 3. check whether hit the cache or build a new program for the input arguments + return CacheKey(function_spec, input_with_spec, class_instance) + + def __hash__(self): + error_msg = "Arguments to a `@paddle.jit.to_static` must be a hashable Python objects (or nested structures of these types)." + return hash((id(self.function_spec), + make_hashable(self.input_with_spec, error_msg), + self.class_instance)) + + def __eq__(self, other): + return (type(self) is type(other)) and hash(self) == hash(other) + + def __neq__(self, other): + return not self == other + + def __repr__(self): + return "id(function_spec): {}, input_with_spec: {}, class_instance: {}".format( + id(self.function_spec), self.input_with_spec, self.class_instance) + + +def unwrap_decorators(func): + """ + Unwraps a decorated function and returns the decorator list and inner target. + """ + decorators = [] + cur = func + while True: + if isinstance(cur, StaticLayer): + decorators.append(cur) + # Note: if `cur` is a method, keep it as bound method of class. + instance = cur._class_instance + if instance is not None: + cur = cur.dygraph_function.__get__(instance) else: - params = layer_instance._parameters - return params + cur = cur.dygraph_function + else: + break + return decorators, cur - def buffers(self, include_sublayer=True): + +class StaticLayer(object): + """ + Wrapper class to Manage program conversion of decorated function. + + """ + + def __init__(self, function, input_spec=None): + """ + Initializes a `StaticLayer`. + + Args: + function(callable): A function or method that will be converted into static program. + input_spec(list[InputSpec]): list of InputSpec to specify the `shape/dtype/name` information for each input argument, default None. """ - Returns Variable buffers of decorated layers. If set `include_sublayer` True, - the Variable buffers created in sub layers will be added. + # save the instance `self` while decorating a method of class. + if inspect.ismethod(function): + self._dygraph_function = getattr(function, '__func__') + self._class_instance = getattr(function, '__self__') + else: + self._dygraph_function = function + self._class_instance = None + + self._input_spec = input_spec + self._function_spec = FunctionSpec(function, input_spec) + self._program_cache = ProgramCache() + # Note: Hold a reference to ProgramTranslator for switching `enable_declarative`. + self._program_trans = ProgramTranslator() + + def __get__(self, instance, owner): """ - buffers = collections.OrderedDict() - if self.is_method(): - layer_instance = self._args[0] - if include_sublayer: - buffers = layer_instance.buffers() - names = [buffer.name for buffer in buffers] - buffers = collections.OrderedDict(zip(names, buffers)) + Overrides this method to parse the class instance and call bound method correctly. + + For example: + + ''' + class Net(Layer): + def __init__(self): + pass + + @paddle.jit.to_static + def forward(self, x, y): + return x + y + + net = Net() + out = net(x, y) + ''' + + In above case, `net(x, y)` will call `net.forward(x, y)` firstly that is a bound method + of `Net` instance. After decorated by `@paddle.jit.to_static`, it will firstly to call `__get__` + to parse the class instance correctly instead of the `StaticLayer` instance. + """ + self._class_instance = instance + return self + + def __call__(self, *args, **kwargs): + """ + Supports to call the returned instance with input `args` and `kwargs` directly. + + Args: + *args(tuple): tuple of all input arguments from original decorated function. + **kwargs(dict): dict of all input keyward arguments from original decorated function. + + Return: + Outputs of decorated function. + """ + # 1. call dygraph function directly if not enable `declarative` + if not self._program_trans.enable_declarative: + warnings.warn( + "The decorator '@paddle.jit.to_static' doesn't work when setting ProgramTranslator.enable=False. " + "We will just return dygraph output.") + return self._call_dygraph_function(*args, **kwargs) + + # 2. trace ops from dygraph layers and cache the generated program. + args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs) + try: + concrete_program, partial_program_layer = self.get_concrete_program( + *args, **kwargs) + + # 3. synchronize self.training attribute. + if isinstance(self._class_instance, layers.Layer): + partial_program_layer.training = self._class_instance.training + + # 4. return outputs. + return partial_program_layer(args) + except Exception as e: + if not hasattr(e, ERROR_DATA): + # runtime error + attach_error_data(e, in_runtime=True) + error_data = getattr(e, ERROR_DATA, None) + if error_data: + new_exception = error_data.create_exception() + if six.PY3: + # NOTE(liym27): + # 1. Why `raise new_exception from None`? + # In Python 3, by default, an new exception is raised with trace information of the caught exception. + # This only raises new_exception and hides unwanted implementation details from tracebacks of the + # caught exception. + # 2. Use exec to bypass syntax error checking in Python 2. + + six.exec_("raise new_exception from None") + else: + raise new_exception else: - buffers = layer_instance._buffers - return buffers + raise - @switch_to_static_graph - def to_static_inputs(self, main_program): - inputs = [] - block = main_program.global_block() - for input_var in flatten(self.args): - if isinstance(input_var, np.ndarray): - feed_layer = block.create_var( - name=unique_name.generate('feed'), - shape=list(input_var.shape), - dtype=input_var.dtype, - is_data=True, - need_check_feed=False) - elif isinstance(input_var, core.VarBase): - feed_layer = block.create_var( - name=input_var.name, - shape=list(input_var.shape), - dtype=input_var.dtype, - stop_gradient=input_var.stop_gradient, - need_check_feed=False) + def _call_dygraph_function(self, *args, **kwargs): + """ + Calls dygraph function directly and returns the outputs. + + Args: + *args(tuple): tuple of all input arguments from original decorated function. + **kwargs(dict): dict of all input keyward arguments from original decorated function. + + Return: + Outputs of dygraph function. + """ + if self._class_instance is not None: + dygraph_function = self._dygraph_function.__get__( + self._class_instance) + else: + dygraph_function = self._dygraph_function + + return dygraph_function(*args, **kwargs) + + def get_concrete_program(self, *args, **kwargs): + """ + Returns traced concrete program and inner executable partial layer. + + Args: + *args(tuple): input arguments values or InputSpec + **kwargs(dict) : input kwargs values. + + Returns: + Traced ConcreteProgram and executable translated Layer. + """ + # 1. unify args/kwargs and replace Tensor with InputSpec + if len(args) != len(self._function_spec.args_name): + args, kwargs = self._function_spec.unified_args_and_kwargs(args, + kwargs) + input_with_spec = self._function_spec.args_to_input_spec(args, kwargs) + + # 2. generate cache key + cache_key = CacheKey(self._function_spec, input_with_spec, + self._class_instance) + + # 3. check whether hit the cache or build a new program for the input arguments + concrete_program, partial_program_layer = self._program_cache[cache_key] + return concrete_program, partial_program_layer + + def get_traced_count(self): + """ + Returns the number of traced programs for the decorated function. + """ + return len(self._program_cache) + + @property + def code(self): + """ + Returns the source code of transformed static function for debugging. + """ + static_func = convert_to_static(self._dygraph_function) + source_code = func_to_source_code(static_func) + return source_code + + @property + def dygraph_function(self): + """ + Returns the original decorated function. + """ + return self._dygraph_function + + @property + def concrete_program(self): + """ + Returns recent ConcreteProgram instance of decorated function. + """ + # if specific the `input_spec`, the length of program_cache will always 1, + # else, return the last one. + cached_program_len = len(self._program_cache) + # If specific `input_spec`, apply convertion from dygraph layers into static Program. + if cached_program_len == 0: + if len(self._function_spec.flat_input_spec) > 0: + input_spec = self._function_spec.input_spec + concrete_program, _ = self.get_concrete_program(*input_spec) + return concrete_program else: - feed_layer = input_var + raise ValueError("No valid transformed program for {}".format( + self._function_spec)) + # If more than one programs have been cached, return the recent converted program by default. + elif cached_program_len > 1: + logging.warning( + "Current {} has more than one cached programs: {}, the last traced progam will be return by default.". + format(self._function_spec, cached_program_len)) + + cache_key, (concrete_program, + partial_layer) = self._program_cache.last() + return concrete_program - inputs.append(feed_layer) - # Restores the nested structure as self.args - return pack_sequence_as(self.args, inputs) + @property + def inputs(self): + """ + Returns input tensors of recent converted static program. + """ + concrete_program = self.concrete_program + inputs = [ + var for var in flatten(concrete_program.inputs) + if isinstance(var, framework.Variable) + ] + return inputs @property - def dyfunc(self): - return self._dyfunc + def outputs(self): + """ + Returns output tensors of recent converted static program. + """ + concrete_program = self.concrete_program + outputs = [ + var for var in flatten(concrete_program.outputs) + if isinstance(var, framework.Variable) + ] + + return outputs @property - def args(self): - return self._args - - def __key(self): - # Note: if dygraph function is a method of class, - # consider instance info as hash key. - if self.is_method(): - # NOTE: we can use Layer's (instance + function code) as hash key. - # An instance will not hold two identical methods - return self._dyfunc_code, self._args[0] - else: - return self._dyfunc + def main_program(self): + """ + Returns recent converted static main program. + """ + concrete_program = self.concrete_program + main_program = concrete_program.main_program + return main_program - def __hash__(self): - return hash(self.__key()) + @property + def program_cache(self): + return self._program_cache - def __eq__(self, other): - return self.__key() == self.__key() + @property + def function_spec(self): + return self._function_spec # Flag that indicates whether running code under `@declarative` @@ -255,11 +483,17 @@ def _switch_declarative_mode_guard_(is_declarative=True): class ConcreteProgram(object): + + __slots__ = [ + 'inputs', 'outputs', 'main_program', "startup_program", "parameters", + "function" + ] + def __init__(self, inputs, outputs, parameters, - func, + function, main_program, startup_program=None): self.inputs = inputs @@ -267,17 +501,21 @@ class ConcreteProgram(object): self.main_program = main_program self.startup_program = startup_program self.parameters = parameters - self.func_spec = func + self.function = function @staticmethod @switch_to_static_graph - def from_func_spec(func_spec): + def from_func_spec(func_spec, input_spec, class_instance): """ Builds the main_program with specialized inputs and returns outputs of program as fetch_list. + + Args: + func_spec(FunctionSpec): A FunctionSpec instance for decorated function. + input_spec(list[InputSpec]): """ # Transforms dygraph function into static function and caches it. - dygraph_function = func_spec.dyfunc + dygraph_function = func_spec.dygraph_function static_func = convert_to_static(dygraph_function) main_program, startup_program = framework.Program(), framework.Program() @@ -291,15 +529,20 @@ class ConcreteProgram(object): with framework.program_guard(main_program, startup_program): with _switch_declarative_mode_guard_(is_declarative=True): # 1. Adds `fluid.data` layers for input if needed - inputs = func_spec.to_static_inputs(main_program) + inputs = func_spec.to_static_inputs_with_spec(input_spec, + main_program) + if class_instance: + inputs = tuple([class_instance] + list(inputs)) # 2. Gets all ParamBases and buffered VarBases in the function - all_parameters_and_buffers = list(func_spec.parameters().values( - )) + list(func_spec.buffers().values()) + all_parameters_and_buffers = list( + get_parameters(class_instance).values()) + list( + get_buffers(class_instance).values()) # 3. Builds program only once and returns the output Variables. - with param_guard(func_spec.parameters(False)), param_guard( - func_spec.buffers(False)): + with param_guard(get_parameters( + class_instance, False)), param_guard( + get_buffers(class_instance, False)): try: outputs = static_func(*inputs) except BaseException as e: @@ -317,7 +560,7 @@ class ConcreteProgram(object): inputs=inputs, outputs=outputs, parameters=all_parameters_and_buffers, - func=dygraph_function, + function=dygraph_function, main_program=main_program, startup_program=startup_program) @@ -330,27 +573,38 @@ class ProgramCache(object): def __init__(self): self._caches = collections.OrderedDict() - def _build_once(self, func_spec): - concrete_program = ConcreteProgram.from_func_spec(func_spec) + def _build_once(self, cache_key): + concrete_program = ConcreteProgram.from_func_spec( + func_spec=cache_key.function_spec, + input_spec=cache_key.input_with_spec, + class_instance=cache_key.class_instance) return concrete_program, partial_program_from(concrete_program) def __getitem__(self, item): - if not isinstance(item, FunctionSpec): - raise ValueError( - 'type(item) should be FunctionSpec, but received %s' % - type(item)) + if not isinstance(item, CacheKey): + raise ValueError('type(item) should be CacheKey, but received %s' % + type_name(item)) + if item not in self._caches: self._caches[item] = self._build_once(item) + # Note: raise warnings if number of traced program is more than `max_tracing_count` + current_tracing_count = len(self._caches) + if current_tracing_count > MAX_TRACED_PROGRAM_COUNT: + logging.warning( + "Current traced program number: {} > `max_tracing_count`:{}. Too much cached programs will bring expensive overhead. " + "The reason may be: (1) passing tensors with different shapes, (2) passing python objects instead of tensors.". + format(current_tracing_count, MAX_TRACED_PROGRAM_COUNT)) + return self._caches[item] def get_program(self, item): - if not isinstance(item, FunctionSpec): + if not isinstance(item, CacheKey): raise ValueError( "Input item's type should be FunctionSpec, but received %s" % - type(item)) + type_name(item)) if item not in self._caches: raise RuntimeError( - "Failed to find program for input item, please decorate input function by `@declarative`." + "Failed to find program for input item, please decorate input function by `@paddle.jit.to_static`." ) return self._caches[item] @@ -360,6 +614,12 @@ class ProgramCache(object): key = next(reversed(self._caches.keys())) return key, self._caches[key] + def __len__(self): + return len(self._caches) + + def concrete_programs(self): + return [cp for key, (cp, _) in self._caches.iteritems()] + def synchronized(func): func.__lock__ = threading.Lock() @@ -508,9 +768,11 @@ class ProgramTranslator(object): "We will just return dygraph output.") return dygraph_func(*args, **kwargs) - function_spec = FunctionSpec(dygraph_func, args, kwargs) - concrete_program, partial_program_layer = self._program_cache[ - function_spec] + function_spec = FunctionSpec(dygraph_func) + cache_key = CacheKey.from_func_and_args(function_spec, args, kwargs, + getattr(dygraph_func, + '__self__', None)) + _, partial_program_layer = self._program_cache[cache_key] if args and isinstance(args[0], layers.Layer): # Synchronize self.training attribute. @@ -624,8 +886,12 @@ class ProgramTranslator(object): "We will just return dygraph output.") return dygraph_func(*args, **kwargs) - func_spec = FunctionSpec(dygraph_func, args, kwargs) - concrete_program, _ = self._program_cache[func_spec] + function_spec = FunctionSpec(dygraph_func) + cache_key = CacheKey.from_func_and_args(function_spec, args, kwargs, + getattr(dygraph_func, + '__self__', None)) + concrete_program, partial_program_layer = self._program_cache[cache_key] + # Note: concrete_program hold all input/output infos include non-Variable input_vars = [ var for var in concrete_program.inputs diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 21e05bc6faf..ba02a983f8e 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -18,12 +18,14 @@ import ast import astor import atexit import copy +import collections import gast import inspect import os import six import tempfile import textwrap +import numpy as np from paddle.fluid import unique_name @@ -46,6 +48,77 @@ dygraph_class_to_static_api = { FOR_ITER_INDEX_PREFIX = '__for_loop_var_index' FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len' +# FullArgSpec is valid from Python3. Defined a Namedtuple to +# to make it available in Python2. +FullArgSpec = collections.namedtuple('FullArgSpec', [ + 'args', 'varargs', 'varkw', 'defaults', 'kwonlyargs', 'kwonlydefaults', + 'annotations' +]) + + +def getfullargspec(target): + if hasattr(inspect, "getfullargspec"): + return inspect.getfullargspec(target) + else: + argspec = inspect.getargspec(target) + return FullArgSpec( + args=argspec.args, + varargs=argspec.varargs, + varkw=argspec.keywords, + defaults=argspec.defaults, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}) + + +def parse_arg_and_kwargs(function): + """ + Returns full argument names as list. e.g ['x', 'y', 'z'] + """ + fullargspec = getfullargspec(function) + arg_names = fullargspec.args + if arg_names and 'self' == arg_names[0]: + arg_names = fullargspec.args[1:] + + # parse default kwargs + default_kwargs = {} + default_values = fullargspec.defaults + if default_values: + assert len(default_values) <= len(arg_names) + default_kwarg_names = arg_names[-len(default_values):] + default_kwargs = dict(zip(default_kwarg_names, default_values)) + + return arg_names, default_kwargs + + +def type_name(v): + return type(v).__name__ + + +def make_hashable(x, error_msg=None): + """ + Makes input `x` hashable. + + For some unhashable objects, such as `dict/list/np.ndarray`,applying hash function by using their values. + """ + if isinstance(x, (tuple, list)): + return tuple(map(make_hashable, x)) + + try: + hash(x) + except TypeError: + if isinstance(x, np.ndarray): + # Note: `tostring()` will return the binary data from np.ndarray that + # means different value will lead to different hash code. + return hash(x.tostring()) + elif isinstance(x, dict): + return tuple(map(make_hashable, x.values())) + + error_msg = error_msg or "Requires a hashable object." + raise ValueError(error_msg + " But received type: %s" % type_name(x)) + + return x + def _is_api_in_module_helper(obj, module_prefix): m = inspect.getmodule(obj) diff --git a/python/paddle/fluid/dygraph/io.py b/python/paddle/fluid/dygraph/io.py index ba27b2d1c63..7f3d450a49c 100644 --- a/python/paddle/fluid/dygraph/io.py +++ b/python/paddle/fluid/dygraph/io.py @@ -378,7 +378,7 @@ def _load_persistable_vars_by_program(model_path, new_var = framework._varbase_creator( type=each_var.type(), name=each_var.name(), - shpae=each_var.shape(), + shape=each_var.shape(), dtype=each_var.dtype(), persistable=True) if params_filename is None: @@ -636,7 +636,7 @@ class TranslatedLayer(layers.Layer): ) if not isinstance(persistable_vars, dict): raise TypeError( - "TranslatedLayer need to use persisatbale variable dict for initialization." + "TranslatedLayer need to use persistable variable dict for initialization." ) self._program_holder_dict = programs @@ -685,7 +685,7 @@ class TranslatedLayer(layers.Layer): # 1. load program desc & construct _ProgramHolder programs = _construct_program_holders(model_path, model_filename) - # 2. load layer parameters & parameter attirbutes + # 2. load layer parameters & parameter attributes persistable_vars = _construct_params_and_buffers( model_path, programs, separate_params, params_filename) @@ -753,7 +753,7 @@ class TranslatedLayer(layers.Layer): core.VarDesc.VarType.STEP_SCOPES, True) tmp_scope_vec.value().set_scope(program_holder.scope) - # 2. run prorgam by op + # 2. run program by op trace_program = program_holder.infer_program if self._is_test else program_holder.train_program end_op_index = program_holder.infer_program.block(0).op_size() framework._dygraph_tracer().trace_op( @@ -774,7 +774,7 @@ class TranslatedLayer(layers.Layer): # will be SelectedRows, not LoDTensor. But tracer will just # set param grad VarBase by forward VarBase(LoDTensor) # If we don't change grad_var type here, RunProgramOp need - # transform SelectedRows to LoDTensor forcely, it may not + # transform SelectedRows to LoDTensor forcibly, it may not # be user wanted result. for persistable_var in persistable_vars: grad_var_name = var.name + core.grad_var_suffix() diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 5a291df4700..b9d20f106ae 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -19,12 +19,12 @@ import pickle import warnings import six +import paddle from paddle.fluid import core from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy from paddle.fluid.data_feeder import check_type from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph -from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA -from paddle.fluid.dygraph.dygraph_to_static.program_translator import FunctionSpec, ProgramTranslator +from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, StaticLayer, unwrap_decorators from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME, VARIABLE_FILENAME, TranslatedLayer from paddle.fluid.dygraph.layers import Layer from paddle.fluid.executor import Executor, scope_guard @@ -128,7 +128,27 @@ def _dygraph_to_static_func_(dygraph_func): dygraph_to_static_func = wrap_decorator(_dygraph_to_static_func_) -def _declarative_(dygraph_func): +def copy_decorator_attrs(original_func, decorated_obj): + """ + Copies some necessary attributes from original function into decorated function. + + Args: + original_func(callable): the original decorated function. + decorated_obj(StaticLayer): the target decorated StaticLayer object. + """ + decorator_name = "declarative" + + decorated_obj.__name__ = original_func.__name__ + decorated_obj._decorator_name = decorator_name + decorated_obj.__wrapped__ = original_func + decorated_obj.__doc__ = original_func.__doc__ + if hasattr(original_func, "__module__"): + decorated_obj.__module__ = original_func.__module__ + + return decorated_obj + + +def declarative(function=None, input_spec=None): """ Converts imperative dygraph APIs into declarative function APIs. Decorator @declarative handles the Program and Executor of static mode and returns @@ -138,7 +158,9 @@ def _declarative_(dygraph_func): converted into declarative function as well. Args: - dygraph_func (callable): callable imperative function. + function (callable): callable imperative function. + input_spec(list[InputSpec]): list of InputSpec to specific the shape/dtype/name + information of each input Tensor. Returns: Tensor(s): containing the numerical result. @@ -167,37 +189,27 @@ def _declarative_(dygraph_func): """ - def __impl__(*args, **kwargs): - program_translator = ProgramTranslator() - if not program_translator.enable_declarative: - warnings.warn( - "The decorator 'declarative' doesn't work when setting ProgramTranslator.enable=False. " - "We will just return dygraph output.") - return dygraph_func(*args, **kwargs) - try: - return program_translator.get_output(dygraph_func, *args, **kwargs) - except Exception as e: - error_data = getattr(e, ERROR_DATA, None) - if error_data: - new_exception = error_data.create_exception() - if six.PY3: - # NOTE(liym27): - # 1. Why `raise new_exception from None`? - # In Python 3, by default, an new exception is raised with trace information of the caught exception. - # This only raises new_exception and hides unwanted implementation details from tracebacks of the - # caught exception. - # 2. Use exec to bypass syntax error checking in Python 2. - - six.exec_("raise new_exception from None") - else: - raise new_exception - else: - raise + def decorated(python_func): + """ + Decorates a python function into a StaticLayer object. + """ + # Step 1. unwrap the function if it is already decorated. + _, python_func = unwrap_decorators(python_func) - return __impl__ + # Step 2. copy some attributes from original python function. + static_layer = copy_decorator_attrs( + original_func=python_func, + decorated_obj=StaticLayer( + function=python_func, input_spec=input_spec)) + + return static_layer + # for usage: `declarative(foo, ...)` + if function is not None: + return decorated(function) -declarative = wrap_decorator(_declarative_) + # for usage: `@declarative` + return decorated class SaveLoadConfig(object): @@ -339,7 +351,7 @@ class SaveLoadConfig(object): # use SaveLoadconfig.output_spec model_path = "simplenet.example.model.output_spec" configs = fluid.dygraph.jit.SaveLoadConfig() - # only keep the predicted output in saved model, diccard loss + # only keep the predicted output in saved model, discard loss configs.output_spec = [out] fluid.dygraph.jit.save( @@ -374,7 +386,7 @@ class SaveLoadConfig(object): The name of file to save the translated program of target Layer. Default filename is :code:`__model__` . - Exampels: + Examples: .. code-block:: python import numpy as np @@ -444,7 +456,7 @@ class SaveLoadConfig(object): The name of file to save all persistable variables in target Layer. Default file name is :code:`__variables__` . - Exampels: + Examples: .. code-block:: python import numpy as np @@ -597,7 +609,7 @@ def save(layer, model_path, input_spec=None, configs=None): The default saved translated program file name is ``__model__``, and the default saved persistable variables file name is ``__variables__``, and it also saved some additional variable description information to file - ``__varibales.info__``, these additional information is used in fine-tuning. + ``__variables.info__``, these additional information is used in fine-tuning. The saved model can be loaded by follow APIs: - :ref:`api_imperative_jit_load` @@ -607,7 +619,7 @@ def save(layer, model_path, input_spec=None, configs=None): Args: layer (Layer): the Layer to be saved. The Layer should be decorated by `@declarative`. model_path (str): the directory to save the model. - input_spec (list[Varibale], optional): Describes the input of the saved model. + input_spec (list[Variable], optional): Describes the input of the saved model. It is the example inputs that will be passed to saved TranslatedLayer's forward function. If None, all input variables of the original Layer's forward function would be the inputs of the saved model. Default None. @@ -721,16 +733,17 @@ def save(layer, model_path, input_spec=None, configs=None): "The input input_spec should be 'list', but received input_spec's type is %s." % type(input_spec)) for var in input_spec: - if not isinstance(var, core.VarBase): + if not isinstance(var, (core.VarBase, Variable, + paddle.static.InputSpec)): raise TypeError( - "The element in input_spec list should be 'Variable', but received element's type is %s." + "The element in input_spec list should be 'Variable' or `paddle.static.InputSpec`, but received element's type is %s." % type(var)) # 2. get program of declarative Layer.forward - prog_cache = prog_translator.get_program_cache() - # make dummy args & kwargs, to get excepted FunctionSpec - layer_func = FunctionSpec(type(layer).forward, [layer], {}) - concrete_program, _ = prog_cache.get_program(layer_func) + if not isinstance(layer.forward, StaticLayer): + raise RuntimeError( + "layer.forward need to be decorated by `@declarative`.") + concrete_program = layer.forward.concrete_program # NOTE: we maintain the mapping of variable name to # structured name, the buffer variable (non-persistable) @@ -814,7 +827,7 @@ def load(model_path, configs=None): For some historical reasons, if you load model saved by :ref:`api_fluid_io_save_inference_model`, there will be the following limitations when using it in fine-tuning: 1. Imperative mode do not support LoDTensor. All original model's feed targets or parametars that depend on LoD are temporarily unavailable. - 2. All saved model's feed targets need to be passed into TranslatedLayer's forwrad function. + 2. All saved model's feed targets need to be passed into TranslatedLayer's forward function. 3. The variable's ``stop_gradient`` information is lost and can not be recovered. 4. The parameter's ``trainable`` information is lost and can not be recovered. diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py new file mode 100644 index 00000000000..4a689354f56 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py @@ -0,0 +1,250 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from paddle.static import InputSpec +import paddle.fluid as fluid +from paddle.fluid.dygraph import to_variable, declarative, ProgramTranslator, Layer, jit + +import unittest + +program_trans = ProgramTranslator() + + +class SimpleNet(Layer): + def __init__(self): + super(SimpleNet, self).__init__() + self.linear = fluid.dygraph.Linear(10, 3) + + @declarative(input_spec=[InputSpec(shape=[None, 10], dtype='float32')]) + def forward(self, x, a=1, b=2): + y = self.inner_function(x) + return y + + # `declarative` is not essential, add it to test for robustness. + @declarative + def inner_function(self, x): + y = self.linear(x) + return y + + def add_func(self, x, y): + z = x + y + return z + + @declarative(input_spec=[[InputSpec([None, 10]), InputSpec([None, 10])]]) + def func_with_list(self, l): + x, y, int_val = l + z = x + y + z = z + int_val + return z + + @declarative(input_spec=[{ + 'x': InputSpec([None, 10]), + 'y': InputSpec([None, 10]) + }]) + def func_with_dict(self, d): + x = d['x'] + y = d['y'] + int_val = d['int_val'] + + z = x + y + z = z + int_val + + return z + + @declarative(input_spec=[[ + InputSpec([None]), { + 'x': InputSpec([None, 10]), + 'y': InputSpec([None, 10]) + } + ]]) + def func_with_list_dict(self, dl): + bias = dl[0] + x = dl[1]['x'] + y = dl[1]['y'] + + z = x + y + z = z + bias + + return z + + +class TestInputSpec(unittest.TestCase): + def setUp(self): + pass + + def test_with_input_spec(self): + with fluid.dygraph.guard(fluid.CPUPlace()): + x = to_variable(np.ones([4, 10]).astype('float32')) + y = to_variable(np.ones([4, 10]).astype('float32') * 2) + int_val = 4. + + net = SimpleNet() + + # 1. each method holds independent program cache + out = net(x) + self.assertTrue(len(net.forward.program_cache) == 1) + + # 2. test save load + jit.save(net, './simple_net') + infer_net = fluid.dygraph.jit.load('./simple_net') + pred = infer_net(x) + self.assertTrue(np.allclose(out.numpy(), pred.numpy())) + + # 3. we can decorate any method + x_2 = to_variable(np.ones([4, 20]).astype('float32')) + # uses `declarative(func)` instead of `@declarative` + net.add_func = declarative(net.add_func) + out = net.add_func(x_2, np.ones([20]).astype('float32')) + self.assertTrue(len(net.add_func.program_cache) == 1) + + # 5. test input with list + out = net.func_with_list([x, y, int_val]) + + # 6. test input with dict + out = net.func_with_dict({'x': x, 'y': y, 'int_val': int_val}) + + # 7. test input with lits contains dict + int_np = np.ones([1]).astype('float32') + out = net.func_with_list_dict([int_np, {'x': x, 'y': y}]) + + def test_with_error(self): + with fluid.dygraph.guard(fluid.CPUPlace()): + x = to_variable(np.ones([4, 10]).astype('float32')) + y = to_variable(np.ones([4, 10]).astype('float32') * 2) + int_val = 4. + + net = SimpleNet() + + # 1. kwargs and input_spec should not be specificed in same time + with self.assertRaises(ValueError): + net(x, a=1, other_kwarg=2) + + # 2. requires len(input_spec) <= len(args) + with self.assertRaises(ValueError): + net.add_func = declarative( + net.add_func, + input_spec=[ + InputSpec([-1, 10]), InputSpec([-1, 10]), + InputSpec([10]) + ]) + net.add_func(x, y) + + def test_concrete_program(self): + with fluid.dygraph.guard(fluid.CPUPlace()): + x = to_variable(np.ones([4, 10]).astype('float32')) + y = to_variable(np.ones([4, 10]).astype('float32') * 2) + int_val = 4. + + net = SimpleNet() + # We can get concrete_program by specificing InputSpec information. Faking input is no need. + net.add_func = declarative( + net.add_func, + input_spec=[ + InputSpec([-1, 10]), InputSpec( + [-1, 10], name='y') + ]) + cp1 = net.add_func.concrete_program + self.assertTrue(cp1.inputs[-1].shape == (-1, 10)) + self.assertTrue(cp1.inputs[-1].name == 'y') + + # generate another program + net.add_func = declarative( + net.add_func, + input_spec=[InputSpec([10]), InputSpec( + [10], name='label')]) + cp2 = net.add_func.concrete_program + self.assertTrue(cp2.inputs[-1].shape == (10, )) + self.assertTrue(cp2.inputs[-1].name == 'label') + # Note(Aurelius84): New instance will be returned if we use `declarative(foo)` every time. + # So number of cache program is 1. + self.assertTrue(len(net.add_func.program_cache) == 1) + self.assertTrue(cp1 != cp2) + + +def foo_func(a, b, c=1, d=2): + z = a + b + return z + + +class TestDifferentInputSpecCacheProgram(unittest.TestCase): + def test_with_different_input(self): + with fluid.dygraph.guard(fluid.CPUPlace()): + x_data = np.ones([16, 10]).astype('float32') + y_data = np.ones([10]).astype('float32') * 2 + z_data = np.ones([10]).astype('float32') * 2.2 + + foo = declarative(foo_func) + + # [16, 10] + [10] (varbase) + out_1 = foo(to_variable(x_data), to_variable(y_data)) + self.assertTrue(np.allclose(x_data + y_data, out_1.numpy())) + self.assertTrue(len(foo.program_cache) == 1) + + # [16, 10] + [10] (numpy) + out_2 = foo(to_variable(x_data), y_data) + self.assertTrue(np.allclose(x_data + y_data, out_2.numpy())) + self.assertTrue(len(foo.program_cache) == 1) + + # [16, 10] + [10] (numpy) + out_3 = foo(to_variable(x_data), z_data) + self.assertTrue(np.allclose(x_data + z_data, out_3.numpy())) + # hit cache program + self.assertTrue(len(foo.program_cache) == 1) + + # [16, 10] + [10] (numpy) with other different arguments (c=3) + out_4 = foo(to_variable(x_data), z_data, 3) + self.assertTrue(np.allclose(x_data + z_data, out_4.numpy())) + # create a new program + self.assertTrue(len(foo.program_cache) == 2) + + def test_get_concrete_program(self): + + foo = declarative(foo_func) + + # 1. specific InputSpec for `x`/`y` + concrete_program_1 = foo.get_concrete_program( + InputSpec([None, 10]), InputSpec([10])) + print(concrete_program_1) + self.assertTrue(len(foo.program_cache) == 1) + + # 2. specific `c`/`d` explicitly with same default value + concrete_program_2 = foo.get_concrete_program( + InputSpec([None, 10]), InputSpec([10]), 1, 2) + self.assertTrue(concrete_program_2 == concrete_program_1) + self.assertTrue(len(foo.program_cache) == 1) + + # 3. specific `c` = 2 + concrete_program_3 = foo.get_concrete_program( + InputSpec([None, 10]), InputSpec([10]), c=2) + self.assertTrue(concrete_program_3 != concrete_program_1) + self.assertTrue(len(foo.program_cache) == 2) + + # 4. specific x.shape = [10] + concrete_program_4 = foo.get_concrete_program( + InputSpec([10]), InputSpec([10])) + self.assertTrue(concrete_program_4 != concrete_program_1) + self.assertTrue(len(foo.program_cache) == 3) + + # 5. only specific InputSpec of x + with self.assertRaises(ValueError): + concrete_program_5 = foo.get_concrete_program(InputSpec([10])) + + # 6. specific unknown kwargs `e`=4 + concrete_program_5 = foo.get_concrete_program( + InputSpec([10]), InputSpec([10]), e=4) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_function_spec.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_function_spec.py new file mode 100644 index 00000000000..88697bc1b36 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_function_spec.py @@ -0,0 +1,116 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.static import InputSpec +from paddle.fluid.dygraph.dygraph_to_static.function_spec import FunctionSpec + +from test_declarative import foo_func + +import unittest + + +class TestFunctionSpec(unittest.TestCase): + def test_constructor(self): + foo_spec = FunctionSpec(foo_func) + args_name = foo_spec.args_name + self.assertListEqual(args_name, ['a', 'b', 'c', 'd']) + self.assertTrue(foo_spec.dygraph_function == foo_func) + self.assertTrue(foo_spec.input_spec is None) + + def test_verify_input_spec(self): + a_spec = InputSpec([None, 10], name='a') + b_spec = InputSpec([10], name='b') + + # type(input_spec) should be list or tuple + with self.assertRaises(TypeError): + foo_spec = FunctionSpec(foo_func, input_spec=a_spec) + + # each element of input_spec should be `InputSpec` + with self.assertRaises(ValueError): + foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, 10]) + + foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec]) + self.assertTrue(len(foo_spec.flat_input_spec) == 2) + + def test_unified_args_and_kwargs(self): + foo_spec = FunctionSpec(foo_func) + # case 1: foo(10, 20, c=4) + args, kwargs = foo_spec.unified_args_and_kwargs([10, 20], {'c': 4}) + self.assertTupleEqual(args, (10, 20, 4, 2)) + self.assertTrue(len(kwargs) == 0) + + # case 2: foo(a=10, b=20, d=4) + args, kwargs = foo_spec.unified_args_and_kwargs( + [], {'a': 10, + 'b': 20, + 'd': 4}) + self.assertTupleEqual(args, (10, 20, 1, 4)) + self.assertTrue(len(kwargs) == 0) + + # case 3: foo(10, b=20) + args, kwargs = foo_spec.unified_args_and_kwargs([10], {'b': 20}) + self.assertTupleEqual(args, (10, 20, 1, 2)) + self.assertTrue(len(kwargs) == 0) + + # assert len(self._arg_names) >= len(args) + with self.assertRaises(ValueError): + foo_spec.unified_args_and_kwargs([10, 20, 30, 40, 50], {'c': 4}) + + # assert arg_name should be in kwargs + with self.assertRaises(ValueError): + foo_spec.unified_args_and_kwargs([10], {'c': 4}) + + def test_args_to_input_spec(self): + a_spec = InputSpec([None, 10], name='a') + b_spec = InputSpec([10], name='b') + + a_tensor = paddle.static.data(name='a_var', shape=[4, 10]) + b_tensor = paddle.static.data(name='b_var', shape=[4, 10]) + kwargs = {'c': 1, 'd': 2} + + # case 1 + foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec]) + input_with_spec = foo_spec.args_to_input_spec( + (a_tensor, b_tensor, 1, 2), {}) + self.assertTrue(len(input_with_spec) == 4) + self.assertTrue(input_with_spec[0] == a_spec) # a + self.assertTrue(input_with_spec[1] == b_spec) # b + self.assertTrue(input_with_spec[2] == 1) # c + self.assertTrue(input_with_spec[3] == 2) # d + + # case 2 + foo_spec = FunctionSpec(foo_func, input_spec=[a_spec]) + input_with_spec = foo_spec.args_to_input_spec((a_tensor, b_tensor), {}) + self.assertTrue(len(input_with_spec) == 2) + self.assertTrue(input_with_spec[0] == a_spec) # a + self.assertTupleEqual(input_with_spec[1].shape, (4, 10)) # b.shape + self.assertEqual(input_with_spec[1].name, 'b_var') # b.name + + # case 3 + # assert kwargs is None if set `input_spec` + foo_spec = FunctionSpec(foo_func, input_spec=[a_spec]) + with self.assertRaises(ValueError): + input_with_spec = foo_spec.args_to_input_spec((a_tensor, b_tensor), + {'c': 4}) + + # case 4 + # assert len(args) >= len(self._input_spec) + foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec]) + with self.assertRaises(ValueError): + input_with_spec = foo_spec.args_to_input_spec((a_tensor, ), {}) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_partial_program.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_partial_program.py index 3da60e955de..f0fbe54f9db 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_partial_program.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_partial_program.py @@ -133,7 +133,7 @@ class TestWithTrainAndEval(unittest.TestCase): x = fluid.dygraph.to_variable(x_data) linear_net(x) - _, partial_layer = program_translator.get_program_cache().last()[-1] + _, partial_layer = linear_net.forward.program_cache.last()[-1] # check default mode is for training self.assertEqual(partial_layer.program, partial_layer._train_program) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py index 0386b7c7a17..6cf59c030c0 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py @@ -133,7 +133,7 @@ class TestPartialProgramRaiseError(unittest.TestCase): x = fluid.dygraph.to_variable(x_data) out = net(x) - program_cache = program_translator.get_program_cache() + program_cache = SimpleFcLayer.forward.program_cache _, (concrete_program, _) = program_cache.last() params = concrete_program.parameters diff --git a/python/paddle/fluid/tests/unittests/test_input_spec.py b/python/paddle/fluid/tests/unittests/test_input_spec.py new file mode 100644 index 00000000000..e329a37488a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_input_spec.py @@ -0,0 +1,113 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle.fluid as fluid +from paddle.static import InputSpec +from paddle.fluid.framework import core, convert_np_dtype_to_dtype_ + + +class TestInputSpec(unittest.TestCase): + def test_default(self): + tensor_spec = InputSpec([3, 4]) + self.assertEqual(tensor_spec.dtype, + convert_np_dtype_to_dtype_('float32')) + self.assertEqual(tensor_spec.name, None) + + def test_from_tensor(self): + x_bool = fluid.layers.fill_constant(shape=[1], dtype='bool', value=True) + bool_spec = InputSpec.from_tensor(x_bool) + self.assertEqual(bool_spec.dtype, x_bool.dtype) + self.assertEqual(bool_spec.shape, x_bool.shape) + self.assertEqual(bool_spec.name, x_bool.name) + + bool_spec2 = InputSpec.from_tensor(x_bool, name='bool_spec') + self.assertEqual(bool_spec2.name, bool_spec2.name) + + def test_from_numpy(self): + x_numpy = np.ones([10, 12]) + x_np_spec = InputSpec.from_numpy(x_numpy) + self.assertEqual(x_np_spec.dtype, + convert_np_dtype_to_dtype_(x_numpy.dtype)) + self.assertEqual(x_np_spec.shape, x_numpy.shape) + self.assertEqual(x_np_spec.name, None) + + x_numpy2 = np.array([1, 2, 3, 4]).astype('int64') + x_np_spec2 = InputSpec.from_numpy(x_numpy2, name='x_np_int64') + self.assertEqual(x_np_spec2.dtype, + convert_np_dtype_to_dtype_(x_numpy2.dtype)) + self.assertEqual(x_np_spec2.shape, x_numpy2.shape) + self.assertEqual(x_np_spec2.name, 'x_np_int64') + + def test_shape_with_none(self): + tensor_spec = InputSpec([None, 4, None], dtype='int8', name='x_spec') + self.assertEqual(tensor_spec.dtype, convert_np_dtype_to_dtype_('int8')) + self.assertEqual(tensor_spec.name, 'x_spec') + self.assertEqual(tensor_spec.shape, (-1, 4, -1)) + + def test_shape_raise_error(self): + # 1. shape should only contain int and None. + with self.assertRaises(ValueError): + tensor_spec = InputSpec(['None', 4, None], dtype='int8') + + # 2. shape should be type `list` or `tuple` + with self.assertRaises(TypeError): + tensor_spec = InputSpec(4, dtype='int8') + + # 3. len(shape) should be greater than 0. + with self.assertRaises(ValueError): + tensor_spec = InputSpec([], dtype='int8') + + def test_batch_and_unbatch(self): + tensor_spec = InputSpec([10]) + # insert batch_size + batch_tensor_spec = tensor_spec.batch(16) + self.assertEqual(batch_tensor_spec.shape, (16, 10)) + + # unbatch + unbatch_spec = batch_tensor_spec.unbatch() + self.assertEqual(unbatch_spec.shape, (10, )) + + # 1. `unbatch` requires len(shape) > 1 + with self.assertRaises(ValueError): + unbatch_spec.unbatch() + + # 2. `batch` requires len(batch_size) == 1 + with self.assertRaises(ValueError): + tensor_spec.batch([16, 12]) + + # 3. `batch` requires type(batch_size) == int + with self.assertRaises(TypeError): + tensor_spec.batch('16') + + def test_eq_and_hash(self): + tensor_spec_1 = InputSpec([10, 16], dtype='float32') + tensor_spec_2 = InputSpec([10, 16], dtype='float32') + tensor_spec_3 = InputSpec([10, 16], dtype='float32', name='x') + tensor_spec_4 = InputSpec([16], dtype='float32', name='x') + + # override ``__eq__`` according to [shape, dtype, name] + self.assertTrue(tensor_spec_1 == tensor_spec_2) + self.assertTrue(tensor_spec_1 != tensor_spec_3) # different name + self.assertTrue(tensor_spec_3 != tensor_spec_4) # different shape + + # override ``__hash__`` according to [shape, dtype] + self.assertTrue(hash(tensor_spec_1) == hash(tensor_spec_2)) + self.assertTrue(hash(tensor_spec_1) == hash(tensor_spec_3)) + self.assertTrue(hash(tensor_spec_3) != hash(tensor_spec_4)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py index 4d7711a5df9..2b79659b9c6 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py @@ -19,11 +19,11 @@ import pickle import unittest import numpy as np -import paddle +from paddle.static import InputSpec import paddle.fluid as fluid from paddle.fluid.dygraph import Linear from paddle.fluid.dygraph import declarative, ProgramTranslator -from paddle.fluid.dygraph.io import VARIABLE_FILENAME, EXTRA_VAR_INFO_FILENAME +from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME BATCH_SIZE = 32 BATCH_NUM = 10 @@ -156,7 +156,7 @@ class TestJitSaveLoad(unittest.TestCase): def load_dygraph_state_dict(self, train_layer): train_layer.eval() - # contruct new model + # construct new model new_layer = LinearNet(784, 1) model_dict, _ = fluid.dygraph.load_dygraph(self.model_path) new_layer.set_dict(model_dict) @@ -176,7 +176,7 @@ class TestJitSaveLoad(unittest.TestCase): model_path=self.model_path, input_spec=example_inputs) - def test_load_dygraoh_no_path(self): + def test_load_dygraph_no_path(self): model_path = "model.test_jit_save_load.no_path" new_layer = LinearNet(784, 1) with self.assertRaises(ValueError): @@ -202,6 +202,92 @@ class TestJitSaveLoad(unittest.TestCase): model_dict, _ = fluid.dygraph.load_dygraph(model_path) +class LinearNetMultiInput(fluid.dygraph.Layer): + def __init__(self, in_size, out_size): + super(LinearNetMultiInput, self).__init__() + self._linear1 = Linear(in_size, out_size) + # self._linear2 = Linear(in_size, out_size) + + @declarative(input_spec=[ + InputSpec( + [None, 8], dtype='float32'), InputSpec( + [None, 8], dtype='float32') + ]) + def forward(self, x, y): + x_out = self._linear1(x) + y_out = self._linear1(y) + loss = fluid.layers.mean(x_out + y_out) + return x_out, y_out, loss + + +class TestSaveLoadWithInputSpec(unittest.TestCase): + def setUp(self): + # enable dygraph mode + fluid.enable_dygraph() + + def test_with_input_spec(self): + net = LinearNetReturnLoss(8, 8) + # set x.shape = [None, 8] + net.forward = declarative( + net.forward, input_spec=[InputSpec( + [None, 8], name='x')]) + + model_path = "model.input_spec.output_spec" + configs = fluid.dygraph.jit.SaveLoadConfig() + # check inputs and outputs + self.assertTrue(len(net.forward.inputs) == 1) + input_x = net.forward.inputs[0] + self.assertTrue(input_x.shape == (-1, 8)) + self.assertTrue(input_x.name == 'x') + + # 1. prune loss + configs.output_spec = net.forward.outputs[:1] + fluid.dygraph.jit.save(net, model_path, configs=configs) + + # 2. load to infer + infer_layer = fluid.dygraph.jit.load(model_path, configs=configs) + x = fluid.dygraph.to_variable( + np.random.random((4, 8)).astype('float32')) + pred = infer_layer(x) + + def test_multi_in_out(self): + net = LinearNetMultiInput(8, 8) + + model_path = "model.multi_inout.output_spec1" + configs = fluid.dygraph.jit.SaveLoadConfig() + # 1. check inputs and outputs + self.assertTrue(len(net.forward.inputs) == 2) + input_x = net.forward.inputs[0] + input_y = net.forward.inputs[1] + self.assertTrue(input_x.shape == (-1, 8)) + self.assertTrue(input_y.shape == (-1, 8)) + + # 2. prune loss + configs.output_spec = net.forward.outputs[:2] + fluid.dygraph.jit.save(net, model_path, configs=configs) + + # 3. load to infer + infer_layer = fluid.dygraph.jit.load(model_path, configs=configs) + x = fluid.dygraph.to_variable( + np.random.random((4, 8)).astype('float32')) + y = fluid.dygraph.to_variable( + np.random.random((4, 8)).astype('float32')) + # 4. predict + pred_x, pred_y = infer_layer(x, y) + + # 1. prune y and loss + model_path = "model.multi_inout.output_spec2" + configs.output_spec = net.forward.outputs[:1] + fluid.dygraph.jit.save(net, model_path, [input_x], configs) + # 2. load again + infer_layer2 = fluid.dygraph.jit.load(model_path, configs=configs) + # 3. predict + pred_xx = infer_layer2(x) + + # 4. assert pred_x == pred_xx + self.assertTrue(np.allclose(pred_x.numpy(), pred_xx.numpy())) + + class TestJitSaveLoadConfig(unittest.TestCase): def setUp(self): # enable dygraph mode diff --git a/python/paddle/incubate/hapi/model.py b/python/paddle/incubate/hapi/model.py index 977f9233a95..e4a6b03f7aa 100644 --- a/python/paddle/incubate/hapi/model.py +++ b/python/paddle/incubate/hapi/model.py @@ -1607,10 +1607,7 @@ class Model(object): % type(layer)) # 2. get program of declarative Layer.forward - prog_cache = prog_translator.get_program_cache() - # make dummy args & kwargs, to get excepted FunctionSpec - layer_func = FunctionSpec(type(layer).forward, [layer], {}) - concrete_program, _ = prog_cache.get_program(layer_func) + concrete_program = layer.forward.concrete_program # NOTE: we maintain the mapping of variable name to # structured name, the buffer variable (non-persistable) @@ -1742,12 +1739,13 @@ class Model(object): out_specs = [] if specs is None: - # If not specific specs of `Input`, using argument names of `forward` function - # to generate `Input`. + # Note(Aurelius84): If not specific specs of `Input`, using argument names of `forward` function + # to generate `Input`. But how can we know the actual shape of each input tensor? if is_input: out_specs = [ - Input(name=n) for n in extract_args(self.network.forward) - if n != 'self' + Input( + name=n, shape=[None]) + for n in extract_args(self.network.forward) if n != 'self' ] else: out_specs = to_list(specs) diff --git a/python/paddle/static/input.py b/python/paddle/static/input.py index 06b9c7cdbef..eb70320ea75 100644 --- a/python/paddle/static/input.py +++ b/python/paddle/static/input.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import paddle -import numpy as np import six -from paddle.fluid import core +import paddle +from paddle.fluid import core, Variable from paddle.fluid.layer_helper import LayerHelper -from paddle.fluid.data_feeder import check_dtype, check_type +from paddle.fluid.data_feeder import check_type +from paddle.fluid.framework import convert_np_dtype_to_dtype_ __all__ = ['data', 'InputSpec'] @@ -41,7 +41,7 @@ def data(name, shape, dtype=None, lod_level=0): size. For example, it is useful to set changeable batch size as "None" or -1. dtype (np.dtype|str, optional): The type of the data. Supported dtype: bool, float16, float32, float64, int8, int16, int32, int64, - uint8. Default: None. When `dtype` is not set, the dtype will get + uint8. Default: None. When `dtype` is not set, the dtype will get from the global dtype by `paddle.get_default_dtype()`. lod_level (int, optional): The LoD level of the LoDTensor. Usually users don't have to set this value. For more details about when and how to @@ -54,13 +54,12 @@ def data(name, shape, dtype=None, lod_level=0): .. code-block:: python import numpy as np - import paddle.fluid as fluid import paddle # Creates a variable with fixed size [3, 2, 1] # User can only feed data of the same shape to x # the dtype is not set, so it will set "float32" by - # paddle.get_default_dtype(). You can use paddle.get_default_dtype() to + # paddle.get_default_dtype(). You can use paddle.get_default_dtype() to # change the global dtype x = paddle.static.data(name='x', shape=[3, 2, 1]) @@ -75,8 +74,8 @@ def data(name, shape, dtype=None, lod_level=0): # and fetch z, like implementing "1 + 1 = 2" in PaddlePaddle feed_data = np.ones(shape=[3, 2, 1], dtype=np.float32) - exe = fluid.Executor(fluid.CPUPlace()) - out = exe.run(fluid.default_main_program(), + exe = paddle.static.Executor(paddle.framework.CPUPlace()) + out = exe.run(paddle.static.default_main_program(), feed={ 'x': feed_data, 'y': feed_data @@ -120,11 +119,13 @@ def data(name, shape, dtype=None, lod_level=0): class InputSpec(object): """ - Define input specification of the model. + InputSpec describes the signature information of the model input, such as ``shape`` , ``dtype`` , ``name`` . + + This interface is often used to specify input tensor information of models in high-level API. + It's also used to specify the tensor information for each input parameter of the forward function + decorated by `@paddle.jit.to_static`. Args: - name (str): The name/alias of the variable, see :ref:`api_guide_Name` - for more details. shape (tuple(integers)|list[integers]): List|Tuple of integers declaring the shape. You can set "None" or -1 at a dimension to indicate the dimension can be of any size. For example, @@ -132,18 +133,28 @@ class InputSpec(object): dtype (np.dtype|str, optional): The type of the data. Supported dtype: bool, float16, float32, float64, int8, int16, int32, int64, uint8. Default: float32. + name (str): The name/alias of the variable, see :ref:`api_guide_Name` + for more details. Examples: .. code-block:: python - from paddle.static import InputSpec + from paddle.static import InputSpec + + input = InputSpec([None, 784], 'float32', 'x') + label = InputSpec([None, 1], 'int64', 'label') - input = InputSpec([None, 784], 'float32', 'x') - label = InputSpec([None, 1], 'int64', 'label') + print(input) # InputSpec(shape=(-1, 784), dtype=VarType.FP32, name=x) + print(label) # InputSpec(shape=(-1, 1), dtype=VarType.INT64, name=label) """ - def __init__(self, shape=None, dtype='float32', name=None): - self.shape = shape + def __init__(self, shape, dtype='float32', name=None): + # replace `None` in shape with -1 + self.shape = self._verify(shape) + # convert dtype into united represention + if dtype is not None: + if not isinstance(dtype, core.VarDesc.VarType): + dtype = convert_np_dtype_to_dtype_(dtype) self.dtype = dtype self.name = name @@ -153,3 +164,167 @@ class InputSpec(object): def __repr__(self): return '{}(shape={}, dtype={}, name={})'.format( type(self).__name__, self.shape, self.dtype, self.name) + + @classmethod + def from_tensor(cls, tensor, name=None): + """ + Generates a InputSpec based on the description of input tensor. + + Args: + tensor(Tensor): the source tensor to generate a InputSpec instance + + Returns: + A InputSpec instance generated from Tensor. + + Examples: + .. code-block:: python + + import numpy as np + import paddle + from paddle.static import InputSpec + + paddle.disable_static() + + x = paddle.to_tensor(np.ones([2, 2], np.float32)) + x_spec = InputSpec.from_tensor(x, name='x') + print(x_spec) # InputSpec(shape=(2, 2), dtype=VarType.FP32, name=x) + + """ + if isinstance(tensor, (Variable, core.VarBase)): + return cls(tensor.shape, tensor.dtype, name or tensor.name) + else: + raise ValueError( + "Input `tensor` should be a Tensor, but received {}.".format( + type(tensor).__name__)) + + @classmethod + def from_numpy(cls, ndarray, name=None): + """ + Generates a InputSpec based on the description of input np.ndarray. + + Args: + tensor(Tensor): the source numpy ndarray to generate a InputSpec instance + + Returns: + A InputSpec instance generated from Tensor. + + Examples: + .. code-block:: python + + import numpy as np + from paddle.static import InputSpec + + x = np.ones([2, 2], np.float32) + x_spec = InputSpec.from_numpy(x, name='x') + print(x_spec) # InputSpec(shape=(2, 2), dtype=VarType.FP32, name=x) + + """ + return cls(ndarray.shape, ndarray.dtype, name) + + def batch(self, batch_size): + """ + Inserts `batch_size` in front of the `shape`. + + Args: + batch_size(int): the inserted integer value of batch size. + + Returns: + The original InputSpec instance by inserting `batch_size` in front of `shape`. + + Examples: + .. code-block:: python + + from paddle.static import InputSpec + + x_spec = InputSpec(shape=[64], dtype='float32', name='x') + x_spec.batch(4) + print(x_spec) # InputSpec(shape=(4, 64), dtype=VarType.FP32, name=x) + + """ + if isinstance(batch_size, (list, tuple)): + if len(batch_size) != 1: + raise ValueError( + "Length of batch_size: {} shall be 1, but received {}.". + format(batch_size, len(batch_size))) + batch_size = batch_size[1] + elif not isinstance(batch_size, six.integer_types): + raise TypeError("type(batch_size) shall be `int`, but received {}.". + format(type(batch_size).__name__)) + + new_shape = [batch_size] + list(self.shape) + self.shape = tuple(new_shape) + + return self + + def unbatch(self): + """ + Removes the first element of `shape`. + + Returns: + The original InputSpec instance by removing the first element of `shape` . + + Examples: + .. code-block:: python + + from paddle.static import InputSpec + + x_spec = InputSpec(shape=[4, 64], dtype='float32', name='x') + x_spec.unbatch() + print(x_spec) # InputSpec(shape=(64,), dtype=VarType.FP32, name=x) + + """ + if len(self.shape) == 0: + raise ValueError( + "Not support to unbatch a InputSpec when len(shape) == 0.") + + self.shape = self._verify(self.shape[1:]) + return self + + def _verify(self, shape): + """ + Verifies the input shape and modifies `None` into `-1`. + """ + if not isinstance(shape, (list, tuple)): + raise TypeError( + "Type of `shape` in InputSpec should be one of (tuple, list), but received {}.". + format(type(shape).__name__)) + if len(shape) == 0: + raise ValueError( + "`shape` in InputSpec should contain at least 1 element, but received {}.". + format(shape)) + + for i, ele in enumerate(shape): + if ele is not None: + if not isinstance(ele, six.integer_types): + raise ValueError( + "shape[{}] should be an `int`, but received `{}`:{}.". + format(i, type(ele).__name__, ele)) + if ele is None or ele < -1: + shape[i] = -1 + + return tuple(shape) + + def __hash__(self): + # Note(Aurelius84): `name` is not considered as a field to compute hashkey. + # Because it's no need to generate a new program in following cases while using + # @paddle.jit.to_static. + # + # Case 1: + # foo(x_var) + # foo(y_var) + # x_var and y_var hold same shape and dtype, they should share a same program. + # + # + # Case 2: + # foo(x_var) + # foo(x_np) # x_np is a numpy.ndarray. + # x_var and x_np hold same shape and dtype, they should also share a same program. + return hash((tuple(self.shape), self.dtype)) + + def __eq__(self, other): + slots = ['shape', 'dtype', 'name'] + return (type(self) is type(other) and all( + getattr(self, attr) == getattr(other, attr) for attr in slots)) + + def __ne__(self, other): + return not self == other -- GitLab