未验证 提交 f0561368 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat] Support InputSpec and Return callable class instance in @declarative (#25960)

* add InputSpec

* add unittest for tensorSpec and SimpleNet
上级 89d7d866
......@@ -15,6 +15,7 @@
import logging
import numpy as np
import sys
import paddle
from paddle.fluid import dygraph
from paddle.fluid.dygraph.nn import Conv2D
from paddle.fluid.dygraph.nn import Linear
......@@ -195,13 +196,16 @@ class ImperativeQuantAware(object):
with dygraph.guard():
model.eval()
input_vars = []
for shape, dtype in zip(input_shape, input_dtype):
raw_data = np.random.random(shape)
input_data = raw_data[np.newaxis, :].astype(
dtype) if append_batch_size else raw_data.astype(dtype)
input_var = dygraph.to_variable(input_data)
input_vars.append(input_var)
outputs = prog_trans.get_output(model.forward, model, *input_vars)
for i, (shape, dtype) in enumerate(zip(input_shape, input_dtype)):
if append_batch_size:
shape = [None] + list(shape)
# Note(Aurelius84): need a elegant way to name this.
in_spec = paddle.static.InputSpec(shape, dtype, 'feed_%d' % i)
input_vars.append(in_spec)
# use `declarative` to convert dygraph into static program
model.forward = dygraph.jit.declarative(
model.forward, input_spec=input_vars)
outputs = model.forward.concrete_program.outputs
input_spec = [input_vars[i] for i in feed]
configs = dygraph.jit.SaveLoadConfig()
configs.separate_params = True
......
......@@ -27,13 +27,12 @@ import types
import numpy
import six
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticLayer
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len
DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func']
program_translator = ProgramTranslator()
to_static_func = program_translator.get_func
def is_builtin(func):
......@@ -63,7 +62,7 @@ def is_paddle_func(func):
def convert_call(func):
"""
Converts a function call which needs to be transformed to static fucntion.
Converts a function call which needs to be transformed to static function.
Args:
func (callable): A callable function or method to convert.
......@@ -98,6 +97,15 @@ def convert_call(func):
func_self = None
converted_call = None
# Function in convert_call may be decorated by another `@declarative`,
# in this case, unwraps it into a raw method or function.
if isinstance(func, StaticLayer):
instance = func._class_instance
if instance is not None:
func = func.dygraph_function.__get__(instance)
else:
func = func.dygraph_function
if is_builtin_len(func):
return convert_len
......@@ -109,11 +117,27 @@ def convert_call(func):
if func.__name__ == '<lambda>':
return func
try:
global_funcs = set([
fn for fn in func.__globals__.values() if inspect.isfunction(fn)
])
if func in global_funcs:
converted_call = to_static_func(func)
# Note(Aurelius84): Because `@declarative` returns a class instance instead of
# a function. This will modify the value referring to itself in `__globals__`.
# For example:
#
# @declarative
# def foo(x):
# return x
#
# `foo` will be converted into a wrapper class, suppose as `StaticLayer`.
# And `foo.__globals__['foo']` will still return this `StaticLayer` instead of
# `foo` function. So `isinstance(fn, StaticLayer)` is added here.
global_functions = set()
for fn in func.__globals__.values():
if inspect.isfunction(fn):
global_functions.add(fn)
elif isinstance(fn, StaticLayer):
global_functions.add(fn.dygraph_function)
if func in global_functions:
converted_call = convert_to_static(func)
func_self = getattr(func, '__self__', None)
except AttributeError:
# NOTE:
......@@ -127,7 +151,7 @@ def convert_call(func):
converted_call = None
elif inspect.ismethod(func):
try:
converted_call = to_static_func(func)
converted_call = convert_to_static(func)
func_self = getattr(func, '__self__', None)
except (IOError, OSError):
# NOTE: func may have been decorated.
......@@ -136,7 +160,7 @@ def convert_call(func):
elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'):
if hasattr(func, 'forward') and isinstance(func, Layer):
try:
forward_func = to_static_func(func.forward)
forward_func = convert_to_static(func.forward)
setattr(func, 'forward', forward_func)
func_self = func
except Exception:
......@@ -146,7 +170,7 @@ def convert_call(func):
else:
try:
call_func = func.__class__.__call__
converted_call = to_static_func(call_func)
converted_call = convert_to_static(call_func)
func_self = func
except Exception:
# NOTE:
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import six
import inspect
import numpy as np
import collections
import paddle
from paddle.fluid import core
from paddle.fluid.dygraph import layers
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.utils import parse_arg_and_kwargs
from paddle.fluid.dygraph.dygraph_to_static.utils import type_name
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
class FunctionSpec(object):
"""
Wrapper class for a function for class method.
"""
def __init__(self, function, input_spec=None):
self._dygraph_function = function
if input_spec is None:
self._input_spec = None
self._flat_input_spec = None
else:
self._input_spec = self._verify_input_spec(input_spec)
self._flat_input_spec = flatten(self._input_spec)
# parse full argument names list.
self._arg_names, self._default_kwargs = parse_arg_and_kwargs(function)
def unified_args_and_kwargs(self, args, kwargs):
"""
Moves kwargs with default value into arguments list to keep `args` contain the same length
value as function definition.
For example:
Given function definition: `def foo(x, a=1, b=2)`,
when calling it by `foo(23)`, the args is `[23]`, kwargs is `{a=1, b=2}`.
In this function, it will return args with `[23, 1, 2]`, kwargs with `{}`
Args:
args(tuple): tuple of input arguments value of decorated function.
kwargs(dict): dict of input keyword arguments value of decorated function.
Return:
New arguments tuple containing default kwargs value.
"""
if len(self._arg_names) < len(args):
error_msg = "The decorated function `{}` requires {} arguments: {}, but received {} with {}.".format(
self._dygraph_function.__name__,
len(self._arg_names), self._arg_names, len(args), args)
if args and inspect.isclass(args[0]):
error_msg += "\n\tMaybe the function has more than one decorator, we don't support this for now."
raise NotImplementedError(error_msg)
else:
raise ValueError(error_msg)
args = list(args)
for i in six.moves.range(len(args), len(self._arg_names)):
arg_name = self._arg_names[i]
if arg_name in kwargs:
args.append(kwargs[arg_name])
del kwargs[arg_name]
else:
if arg_name not in self._default_kwargs:
raise ValueError(
"`{}()` requires `{}` arguments, but not found in input `args`: {} and `kwargs`: {}.".
format(self._dygraph_function.__name__, arg_name, args,
kwargs))
args.append(self._default_kwargs[arg_name])
return tuple(args), kwargs
def args_to_input_spec(self, args, kwargs):
"""
Converts input arguments into InputSpec.
1. If specific input_spec, use them to construct feed layers.
2. If input_spec is None, consider all Tensor and Numpy.ndarray as feed layers
Args:
args(tuple): tuple of input arguments value of function containing default kwargs value.
kwargs(dict): kwargs arguments received by **kwargs.
Return:
Same nest structure with args by replacing value with InputSpec.
"""
input_with_spec = []
if self._input_spec is not None:
# Note: Because the value type and length of `kwargs` is uncertain.
# So we don't support to deal this case while specificing `input_spec` currently.
if kwargs:
raise ValueError(
"{} got unexpected keyword arguments: {}. Cannot trace the function when `input_spec` is specificed.".
format(self._dygraph_function.__name__, kwargs))
# Note: The length of `input_spec` can be greater than `args`,
# because `args` may contains non-tensor value merged form `kwargs`
# after `unified_args_and_kwargs`.
if len(args) < len(self._input_spec):
raise ValueError(
"Requires len(arguments) >= len(input_spec), but received len(args):{} < len(InputSpec): {}".
format(len(args), len(self._input_spec)))
# replace argument with corresponding InputSpec.
input_with_spec = convert_to_input_spec(args, self._input_spec)
else:
for idx, input_var in enumerate(flatten(args)):
if isinstance(input_var, np.ndarray):
input_var = paddle.static.InputSpec.from_numpy(input_var)
elif isinstance(input_var, core.VarBase):
input_var = paddle.static.InputSpec.from_tensor(input_var)
input_with_spec.append(input_var)
input_with_spec = pack_sequence_as(args, input_with_spec)
return input_with_spec
@switch_to_static_graph
def to_static_inputs_with_spec(self, input_with_spec, main_program):
"""
Constructs feed layer by inputs with InputSpec information for main program.
Args:
input_with_spec(tuple): input arguments by replacing argument with InputSpec.
main_program(Program): main program for inserting feed layer.
"""
flat_input_spec = flatten(input_with_spec)
inputs = []
block = main_program.global_block()
for i, var_spec in enumerate(flat_input_spec):
if isinstance(var_spec, paddle.static.InputSpec):
feed_layer = block.create_var(
# TODO(Aurelius84): consider a more elegant way to name this
name=var_spec.name or "feed_%s" % i,
shape=var_spec.shape,
dtype=var_spec.dtype,
is_data=True,
need_check_feed=False)
else:
feed_layer = var_spec
inputs.append(feed_layer)
return pack_sequence_as(input_with_spec, inputs)
def _verify_input_spec(self, input_spec):
"""
Verifies the `input_spec` and its element type is valid.
"""
if not isinstance(input_spec, (tuple, list)):
raise TypeError(
"The type(input_spec) should be one of (tuple, list), but received {}.".
format(type_name(input_spec)))
input_spec = tuple(input_spec)
for spec in flatten(input_spec):
if not isinstance(spec, paddle.static.InputSpec):
raise ValueError(
"The type(elem) from input_spec should be `InputSpec`, but received {}.".
format(type_name(spec)))
return input_spec
def __repr__(self):
return "function: {}({}), input_spec: {}".format(
self._dygraph_function.__name__, ','.join(self._arg_names),
self._input_spec)
@property
def dygraph_function(self):
return self._dygraph_function
@property
def args_name(self):
return self._arg_names
@property
def input_spec(self):
return self._input_spec
@property
def flat_input_spec(self):
return self._flat_input_spec
@property
def code(self):
return func_to_source_code(self._dygraph_function)
def get_parameters(layer_instance, include_sublayer=True):
"""
Returns parameters of decorated layers. If set `include_sublayer` True,
the parameters created in sub layers will be added.
"""
params = collections.OrderedDict()
if layer_instance is not None:
if isinstance(layer_instance, layers.Layer):
if include_sublayer:
params = layer_instance.parameters()
names = [p.name for p in params]
params = collections.OrderedDict(zip(names, params))
else:
params = layer_instance._parameters
else:
raise TypeError(
"Type of `layer_instance` should be nn.Layer, but received {}".
format(type_name(layer_instance)))
return params
def get_buffers(layer_instance, include_sublayer=True):
"""
Returns Variable buffers of decorated layers. If set `include_sublayer` True,
the Variable buffers created in sub layers will be added.
"""
buffers = collections.OrderedDict()
if layer_instance is not None:
if isinstance(layer_instance, layers.Layer):
if include_sublayer:
buffers = layer_instance.buffers()
names = [buffer.name for buffer in buffers]
buffers = collections.OrderedDict(zip(names, buffers))
else:
buffers = layer_instance._buffers
else:
raise TypeError(
"Type of `layer_instance` should be nn.Layer, but received {}".
format(type_name(layer_instance)))
return buffers
def convert_to_input_spec(inputs, input_spec):
"""
Replaces tensor in structured `inputs` by InputSpec in `input_spec`.
Args:
inputs(list|dict): nested structure list or dict.
input_spec(list|dict): same nested structure list or dict as inputs.
Return:
Same structure with inputs by replacing the element with specified InputSpec.
"""
def check_type_and_len(input, spec, check_length=False):
if type(input) is not type(spec):
raise TypeError('type(input) should be {}, but received {}.'.format(
type(spec), type(input)))
if check_length and len(input) < len(spec):
raise ValueError(
'Requires len(inputs) >= len(input_spec), but received len(inputs):{} < len(input_spec):{}'.
format(len(inputs), len(input_spec)))
if isinstance(input_spec, (tuple, list)):
input_with_spec = []
check_type_and_len(inputs, input_spec, True)
for i, spec in enumerate(input_spec):
out_spec = convert_to_input_spec(inputs[i], spec)
input_with_spec.append(out_spec)
# Note: If the rest inputs contain tensor or numpy.ndarray
# without specific InputSpec, raise warning.
if len(inputs) > len(input_spec):
for rest_input in inputs[len(input_spec):]:
if isinstance(rest_input, (core.VarBase, np.ndarray)):
logging.warning(
"The inputs constain `{}` without specificing InputSpec, its shape and dtype will be treated immutable. "
"Please specific InputSpec information in `@declarative` if you expect them as mutable inputs.".
format(type_name(rest_input)))
input_with_spec.extend(inputs[len(input_spec):])
return input_with_spec
elif isinstance(input_spec, dict):
input_with_spec = {}
check_type_and_len(inputs, input_spec, True)
for name, input in inputs.items():
if name in input_spec:
input_with_spec[name] = convert_to_input_spec(input,
input_spec[name])
else:
input_with_spec[name] = input
return input_with_spec
elif isinstance(input_spec, paddle.static.InputSpec):
return input_spec
else:
raise TypeError(
"The type(input_spec) should be a `InputSpec` or dict/list/tuple of it, but received {}.".
type_name(input_spec))
......@@ -18,12 +18,14 @@ import ast
import astor
import atexit
import copy
import collections
import gast
import inspect
import os
import six
import tempfile
import textwrap
import numpy as np
from paddle.fluid import unique_name
......@@ -46,6 +48,77 @@ dygraph_class_to_static_api = {
FOR_ITER_INDEX_PREFIX = '__for_loop_var_index'
FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len'
# FullArgSpec is valid from Python3. Defined a Namedtuple to
# to make it available in Python2.
FullArgSpec = collections.namedtuple('FullArgSpec', [
'args', 'varargs', 'varkw', 'defaults', 'kwonlyargs', 'kwonlydefaults',
'annotations'
])
def getfullargspec(target):
if hasattr(inspect, "getfullargspec"):
return inspect.getfullargspec(target)
else:
argspec = inspect.getargspec(target)
return FullArgSpec(
args=argspec.args,
varargs=argspec.varargs,
varkw=argspec.keywords,
defaults=argspec.defaults,
kwonlyargs=[],
kwonlydefaults=None,
annotations={})
def parse_arg_and_kwargs(function):
"""
Returns full argument names as list. e.g ['x', 'y', 'z']
"""
fullargspec = getfullargspec(function)
arg_names = fullargspec.args
if arg_names and 'self' == arg_names[0]:
arg_names = fullargspec.args[1:]
# parse default kwargs
default_kwargs = {}
default_values = fullargspec.defaults
if default_values:
assert len(default_values) <= len(arg_names)
default_kwarg_names = arg_names[-len(default_values):]
default_kwargs = dict(zip(default_kwarg_names, default_values))
return arg_names, default_kwargs
def type_name(v):
return type(v).__name__
def make_hashable(x, error_msg=None):
"""
Makes input `x` hashable.
For some unhashable objects, such as `dict/list/np.ndarray`,applying hash function by using their values.
"""
if isinstance(x, (tuple, list)):
return tuple(map(make_hashable, x))
try:
hash(x)
except TypeError:
if isinstance(x, np.ndarray):
# Note: `tostring()` will return the binary data from np.ndarray that
# means different value will lead to different hash code.
return hash(x.tostring())
elif isinstance(x, dict):
return tuple(map(make_hashable, x.values()))
error_msg = error_msg or "Requires a hashable object."
raise ValueError(error_msg + " But received type: %s" % type_name(x))
return x
def _is_api_in_module_helper(obj, module_prefix):
m = inspect.getmodule(obj)
......
......@@ -378,7 +378,7 @@ def _load_persistable_vars_by_program(model_path,
new_var = framework._varbase_creator(
type=each_var.type(),
name=each_var.name(),
shpae=each_var.shape(),
shape=each_var.shape(),
dtype=each_var.dtype(),
persistable=True)
if params_filename is None:
......@@ -636,7 +636,7 @@ class TranslatedLayer(layers.Layer):
)
if not isinstance(persistable_vars, dict):
raise TypeError(
"TranslatedLayer need to use persisatbale variable dict for initialization."
"TranslatedLayer need to use persistable variable dict for initialization."
)
self._program_holder_dict = programs
......@@ -685,7 +685,7 @@ class TranslatedLayer(layers.Layer):
# 1. load program desc & construct _ProgramHolder
programs = _construct_program_holders(model_path, model_filename)
# 2. load layer parameters & parameter attirbutes
# 2. load layer parameters & parameter attributes
persistable_vars = _construct_params_and_buffers(
model_path, programs, separate_params, params_filename)
......@@ -753,7 +753,7 @@ class TranslatedLayer(layers.Layer):
core.VarDesc.VarType.STEP_SCOPES, True)
tmp_scope_vec.value().set_scope(program_holder.scope)
# 2. run prorgam by op
# 2. run program by op
trace_program = program_holder.infer_program if self._is_test else program_holder.train_program
end_op_index = program_holder.infer_program.block(0).op_size()
framework._dygraph_tracer().trace_op(
......@@ -774,7 +774,7 @@ class TranslatedLayer(layers.Layer):
# will be SelectedRows, not LoDTensor. But tracer will just
# set param grad VarBase by forward VarBase(LoDTensor)
# If we don't change grad_var type here, RunProgramOp need
# transform SelectedRows to LoDTensor forcely, it may not
# transform SelectedRows to LoDTensor forcibly, it may not
# be user wanted result.
for persistable_var in persistable_vars:
grad_var_name = var.name + core.grad_var_suffix()
......
......@@ -19,12 +19,12 @@ import pickle
import warnings
import six
import paddle
from paddle.fluid import core
from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA
from paddle.fluid.dygraph.dygraph_to_static.program_translator import FunctionSpec, ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, StaticLayer, unwrap_decorators
from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME, VARIABLE_FILENAME, TranslatedLayer
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.executor import Executor, scope_guard
......@@ -128,7 +128,27 @@ def _dygraph_to_static_func_(dygraph_func):
dygraph_to_static_func = wrap_decorator(_dygraph_to_static_func_)
def _declarative_(dygraph_func):
def copy_decorator_attrs(original_func, decorated_obj):
"""
Copies some necessary attributes from original function into decorated function.
Args:
original_func(callable): the original decorated function.
decorated_obj(StaticLayer): the target decorated StaticLayer object.
"""
decorator_name = "declarative"
decorated_obj.__name__ = original_func.__name__
decorated_obj._decorator_name = decorator_name
decorated_obj.__wrapped__ = original_func
decorated_obj.__doc__ = original_func.__doc__
if hasattr(original_func, "__module__"):
decorated_obj.__module__ = original_func.__module__
return decorated_obj
def declarative(function=None, input_spec=None):
"""
Converts imperative dygraph APIs into declarative function APIs. Decorator
@declarative handles the Program and Executor of static mode and returns
......@@ -138,7 +158,9 @@ def _declarative_(dygraph_func):
converted into declarative function as well.
Args:
dygraph_func (callable): callable imperative function.
function (callable): callable imperative function.
input_spec(list[InputSpec]): list of InputSpec to specific the shape/dtype/name
information of each input Tensor.
Returns:
Tensor(s): containing the numerical result.
......@@ -167,37 +189,27 @@ def _declarative_(dygraph_func):
"""
def __impl__(*args, **kwargs):
program_translator = ProgramTranslator()
if not program_translator.enable_declarative:
warnings.warn(
"The decorator 'declarative' doesn't work when setting ProgramTranslator.enable=False. "
"We will just return dygraph output.")
return dygraph_func(*args, **kwargs)
try:
return program_translator.get_output(dygraph_func, *args, **kwargs)
except Exception as e:
error_data = getattr(e, ERROR_DATA, None)
if error_data:
new_exception = error_data.create_exception()
if six.PY3:
# NOTE(liym27):
# 1. Why `raise new_exception from None`?
# In Python 3, by default, an new exception is raised with trace information of the caught exception.
# This only raises new_exception and hides unwanted implementation details from tracebacks of the
# caught exception.
# 2. Use exec to bypass syntax error checking in Python 2.
six.exec_("raise new_exception from None")
else:
raise new_exception
else:
raise
def decorated(python_func):
"""
Decorates a python function into a StaticLayer object.
"""
# Step 1. unwrap the function if it is already decorated.
_, python_func = unwrap_decorators(python_func)
return __impl__
# Step 2. copy some attributes from original python function.
static_layer = copy_decorator_attrs(
original_func=python_func,
decorated_obj=StaticLayer(
function=python_func, input_spec=input_spec))
return static_layer
# for usage: `declarative(foo, ...)`
if function is not None:
return decorated(function)
declarative = wrap_decorator(_declarative_)
# for usage: `@declarative`
return decorated
class SaveLoadConfig(object):
......@@ -339,7 +351,7 @@ class SaveLoadConfig(object):
# use SaveLoadconfig.output_spec
model_path = "simplenet.example.model.output_spec"
configs = fluid.dygraph.jit.SaveLoadConfig()
# only keep the predicted output in saved model, diccard loss
# only keep the predicted output in saved model, discard loss
configs.output_spec = [out]
fluid.dygraph.jit.save(
......@@ -374,7 +386,7 @@ class SaveLoadConfig(object):
The name of file to save the translated program of target Layer.
Default filename is :code:`__model__` .
Exampels:
Examples:
.. code-block:: python
import numpy as np
......@@ -444,7 +456,7 @@ class SaveLoadConfig(object):
The name of file to save all persistable variables in target Layer.
Default file name is :code:`__variables__` .
Exampels:
Examples:
.. code-block:: python
import numpy as np
......@@ -597,7 +609,7 @@ def save(layer, model_path, input_spec=None, configs=None):
The default saved translated program file name is ``__model__``,
and the default saved persistable variables file name is ``__variables__``,
and it also saved some additional variable description information to file
``__varibales.info__``, these additional information is used in fine-tuning.
``__variables.info__``, these additional information is used in fine-tuning.
The saved model can be loaded by follow APIs:
- :ref:`api_imperative_jit_load`
......@@ -607,7 +619,7 @@ def save(layer, model_path, input_spec=None, configs=None):
Args:
layer (Layer): the Layer to be saved. The Layer should be decorated by `@declarative`.
model_path (str): the directory to save the model.
input_spec (list[Varibale], optional): Describes the input of the saved model.
input_spec (list[Variable], optional): Describes the input of the saved model.
It is the example inputs that will be passed to saved TranslatedLayer's forward
function. If None, all input variables of the original Layer's forward function
would be the inputs of the saved model. Default None.
......@@ -721,16 +733,17 @@ def save(layer, model_path, input_spec=None, configs=None):
"The input input_spec should be 'list', but received input_spec's type is %s."
% type(input_spec))
for var in input_spec:
if not isinstance(var, core.VarBase):
if not isinstance(var, (core.VarBase, Variable,
paddle.static.InputSpec)):
raise TypeError(
"The element in input_spec list should be 'Variable', but received element's type is %s."
"The element in input_spec list should be 'Variable' or `paddle.static.InputSpec`, but received element's type is %s."
% type(var))
# 2. get program of declarative Layer.forward
prog_cache = prog_translator.get_program_cache()
# make dummy args & kwargs, to get excepted FunctionSpec
layer_func = FunctionSpec(type(layer).forward, [layer], {})
concrete_program, _ = prog_cache.get_program(layer_func)
if not isinstance(layer.forward, StaticLayer):
raise RuntimeError(
"layer.forward need to be decorated by `@declarative`.")
concrete_program = layer.forward.concrete_program
# NOTE: we maintain the mapping of variable name to
# structured name, the buffer variable (non-persistable)
......@@ -814,7 +827,7 @@ def load(model_path, configs=None):
For some historical reasons, if you load model saved by :ref:`api_fluid_io_save_inference_model`,
there will be the following limitations when using it in fine-tuning:
1. Imperative mode do not support LoDTensor. All original model's feed targets or parametars that depend on LoD are temporarily unavailable.
2. All saved model's feed targets need to be passed into TranslatedLayer's forwrad function.
2. All saved model's feed targets need to be passed into TranslatedLayer's forward function.
3. The variable's ``stop_gradient`` information is lost and can not be recovered.
4. The parameter's ``trainable`` information is lost and can not be recovered.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from paddle.static import InputSpec
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable, declarative, ProgramTranslator, Layer, jit
import unittest
program_trans = ProgramTranslator()
class SimpleNet(Layer):
def __init__(self):
super(SimpleNet, self).__init__()
self.linear = fluid.dygraph.Linear(10, 3)
@declarative(input_spec=[InputSpec(shape=[None, 10], dtype='float32')])
def forward(self, x, a=1, b=2):
y = self.inner_function(x)
return y
# `declarative` is not essential, add it to test for robustness.
@declarative
def inner_function(self, x):
y = self.linear(x)
return y
def add_func(self, x, y):
z = x + y
return z
@declarative(input_spec=[[InputSpec([None, 10]), InputSpec([None, 10])]])
def func_with_list(self, l):
x, y, int_val = l
z = x + y
z = z + int_val
return z
@declarative(input_spec=[{
'x': InputSpec([None, 10]),
'y': InputSpec([None, 10])
}])
def func_with_dict(self, d):
x = d['x']
y = d['y']
int_val = d['int_val']
z = x + y
z = z + int_val
return z
@declarative(input_spec=[[
InputSpec([None]), {
'x': InputSpec([None, 10]),
'y': InputSpec([None, 10])
}
]])
def func_with_list_dict(self, dl):
bias = dl[0]
x = dl[1]['x']
y = dl[1]['y']
z = x + y
z = z + bias
return z
class TestInputSpec(unittest.TestCase):
def setUp(self):
pass
def test_with_input_spec(self):
with fluid.dygraph.guard(fluid.CPUPlace()):
x = to_variable(np.ones([4, 10]).astype('float32'))
y = to_variable(np.ones([4, 10]).astype('float32') * 2)
int_val = 4.
net = SimpleNet()
# 1. each method holds independent program cache
out = net(x)
self.assertTrue(len(net.forward.program_cache) == 1)
# 2. test save load
jit.save(net, './simple_net')
infer_net = fluid.dygraph.jit.load('./simple_net')
pred = infer_net(x)
self.assertTrue(np.allclose(out.numpy(), pred.numpy()))
# 3. we can decorate any method
x_2 = to_variable(np.ones([4, 20]).astype('float32'))
# uses `declarative(func)` instead of `@declarative`
net.add_func = declarative(net.add_func)
out = net.add_func(x_2, np.ones([20]).astype('float32'))
self.assertTrue(len(net.add_func.program_cache) == 1)
# 5. test input with list
out = net.func_with_list([x, y, int_val])
# 6. test input with dict
out = net.func_with_dict({'x': x, 'y': y, 'int_val': int_val})
# 7. test input with lits contains dict
int_np = np.ones([1]).astype('float32')
out = net.func_with_list_dict([int_np, {'x': x, 'y': y}])
def test_with_error(self):
with fluid.dygraph.guard(fluid.CPUPlace()):
x = to_variable(np.ones([4, 10]).astype('float32'))
y = to_variable(np.ones([4, 10]).astype('float32') * 2)
int_val = 4.
net = SimpleNet()
# 1. kwargs and input_spec should not be specificed in same time
with self.assertRaises(ValueError):
net(x, a=1, other_kwarg=2)
# 2. requires len(input_spec) <= len(args)
with self.assertRaises(ValueError):
net.add_func = declarative(
net.add_func,
input_spec=[
InputSpec([-1, 10]), InputSpec([-1, 10]),
InputSpec([10])
])
net.add_func(x, y)
def test_concrete_program(self):
with fluid.dygraph.guard(fluid.CPUPlace()):
x = to_variable(np.ones([4, 10]).astype('float32'))
y = to_variable(np.ones([4, 10]).astype('float32') * 2)
int_val = 4.
net = SimpleNet()
# We can get concrete_program by specificing InputSpec information. Faking input is no need.
net.add_func = declarative(
net.add_func,
input_spec=[
InputSpec([-1, 10]), InputSpec(
[-1, 10], name='y')
])
cp1 = net.add_func.concrete_program
self.assertTrue(cp1.inputs[-1].shape == (-1, 10))
self.assertTrue(cp1.inputs[-1].name == 'y')
# generate another program
net.add_func = declarative(
net.add_func,
input_spec=[InputSpec([10]), InputSpec(
[10], name='label')])
cp2 = net.add_func.concrete_program
self.assertTrue(cp2.inputs[-1].shape == (10, ))
self.assertTrue(cp2.inputs[-1].name == 'label')
# Note(Aurelius84): New instance will be returned if we use `declarative(foo)` every time.
# So number of cache program is 1.
self.assertTrue(len(net.add_func.program_cache) == 1)
self.assertTrue(cp1 != cp2)
def foo_func(a, b, c=1, d=2):
z = a + b
return z
class TestDifferentInputSpecCacheProgram(unittest.TestCase):
def test_with_different_input(self):
with fluid.dygraph.guard(fluid.CPUPlace()):
x_data = np.ones([16, 10]).astype('float32')
y_data = np.ones([10]).astype('float32') * 2
z_data = np.ones([10]).astype('float32') * 2.2
foo = declarative(foo_func)
# [16, 10] + [10] (varbase)
out_1 = foo(to_variable(x_data), to_variable(y_data))
self.assertTrue(np.allclose(x_data + y_data, out_1.numpy()))
self.assertTrue(len(foo.program_cache) == 1)
# [16, 10] + [10] (numpy)
out_2 = foo(to_variable(x_data), y_data)
self.assertTrue(np.allclose(x_data + y_data, out_2.numpy()))
self.assertTrue(len(foo.program_cache) == 1)
# [16, 10] + [10] (numpy)
out_3 = foo(to_variable(x_data), z_data)
self.assertTrue(np.allclose(x_data + z_data, out_3.numpy()))
# hit cache program
self.assertTrue(len(foo.program_cache) == 1)
# [16, 10] + [10] (numpy) with other different arguments (c=3)
out_4 = foo(to_variable(x_data), z_data, 3)
self.assertTrue(np.allclose(x_data + z_data, out_4.numpy()))
# create a new program
self.assertTrue(len(foo.program_cache) == 2)
def test_get_concrete_program(self):
foo = declarative(foo_func)
# 1. specific InputSpec for `x`/`y`
concrete_program_1 = foo.get_concrete_program(
InputSpec([None, 10]), InputSpec([10]))
print(concrete_program_1)
self.assertTrue(len(foo.program_cache) == 1)
# 2. specific `c`/`d` explicitly with same default value
concrete_program_2 = foo.get_concrete_program(
InputSpec([None, 10]), InputSpec([10]), 1, 2)
self.assertTrue(concrete_program_2 == concrete_program_1)
self.assertTrue(len(foo.program_cache) == 1)
# 3. specific `c` = 2
concrete_program_3 = foo.get_concrete_program(
InputSpec([None, 10]), InputSpec([10]), c=2)
self.assertTrue(concrete_program_3 != concrete_program_1)
self.assertTrue(len(foo.program_cache) == 2)
# 4. specific x.shape = [10]
concrete_program_4 = foo.get_concrete_program(
InputSpec([10]), InputSpec([10]))
self.assertTrue(concrete_program_4 != concrete_program_1)
self.assertTrue(len(foo.program_cache) == 3)
# 5. only specific InputSpec of x
with self.assertRaises(ValueError):
concrete_program_5 = foo.get_concrete_program(InputSpec([10]))
# 6. specific unknown kwargs `e`=4
concrete_program_5 = foo.get_concrete_program(
InputSpec([10]), InputSpec([10]), e=4)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.static import InputSpec
from paddle.fluid.dygraph.dygraph_to_static.function_spec import FunctionSpec
from test_declarative import foo_func
import unittest
class TestFunctionSpec(unittest.TestCase):
def test_constructor(self):
foo_spec = FunctionSpec(foo_func)
args_name = foo_spec.args_name
self.assertListEqual(args_name, ['a', 'b', 'c', 'd'])
self.assertTrue(foo_spec.dygraph_function == foo_func)
self.assertTrue(foo_spec.input_spec is None)
def test_verify_input_spec(self):
a_spec = InputSpec([None, 10], name='a')
b_spec = InputSpec([10], name='b')
# type(input_spec) should be list or tuple
with self.assertRaises(TypeError):
foo_spec = FunctionSpec(foo_func, input_spec=a_spec)
# each element of input_spec should be `InputSpec`
with self.assertRaises(ValueError):
foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, 10])
foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec])
self.assertTrue(len(foo_spec.flat_input_spec) == 2)
def test_unified_args_and_kwargs(self):
foo_spec = FunctionSpec(foo_func)
# case 1: foo(10, 20, c=4)
args, kwargs = foo_spec.unified_args_and_kwargs([10, 20], {'c': 4})
self.assertTupleEqual(args, (10, 20, 4, 2))
self.assertTrue(len(kwargs) == 0)
# case 2: foo(a=10, b=20, d=4)
args, kwargs = foo_spec.unified_args_and_kwargs(
[], {'a': 10,
'b': 20,
'd': 4})
self.assertTupleEqual(args, (10, 20, 1, 4))
self.assertTrue(len(kwargs) == 0)
# case 3: foo(10, b=20)
args, kwargs = foo_spec.unified_args_and_kwargs([10], {'b': 20})
self.assertTupleEqual(args, (10, 20, 1, 2))
self.assertTrue(len(kwargs) == 0)
# assert len(self._arg_names) >= len(args)
with self.assertRaises(ValueError):
foo_spec.unified_args_and_kwargs([10, 20, 30, 40, 50], {'c': 4})
# assert arg_name should be in kwargs
with self.assertRaises(ValueError):
foo_spec.unified_args_and_kwargs([10], {'c': 4})
def test_args_to_input_spec(self):
a_spec = InputSpec([None, 10], name='a')
b_spec = InputSpec([10], name='b')
a_tensor = paddle.static.data(name='a_var', shape=[4, 10])
b_tensor = paddle.static.data(name='b_var', shape=[4, 10])
kwargs = {'c': 1, 'd': 2}
# case 1
foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec])
input_with_spec = foo_spec.args_to_input_spec(
(a_tensor, b_tensor, 1, 2), {})
self.assertTrue(len(input_with_spec) == 4)
self.assertTrue(input_with_spec[0] == a_spec) # a
self.assertTrue(input_with_spec[1] == b_spec) # b
self.assertTrue(input_with_spec[2] == 1) # c
self.assertTrue(input_with_spec[3] == 2) # d
# case 2
foo_spec = FunctionSpec(foo_func, input_spec=[a_spec])
input_with_spec = foo_spec.args_to_input_spec((a_tensor, b_tensor), {})
self.assertTrue(len(input_with_spec) == 2)
self.assertTrue(input_with_spec[0] == a_spec) # a
self.assertTupleEqual(input_with_spec[1].shape, (4, 10)) # b.shape
self.assertEqual(input_with_spec[1].name, 'b_var') # b.name
# case 3
# assert kwargs is None if set `input_spec`
foo_spec = FunctionSpec(foo_func, input_spec=[a_spec])
with self.assertRaises(ValueError):
input_with_spec = foo_spec.args_to_input_spec((a_tensor, b_tensor),
{'c': 4})
# case 4
# assert len(args) >= len(self._input_spec)
foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec])
with self.assertRaises(ValueError):
input_with_spec = foo_spec.args_to_input_spec((a_tensor, ), {})
if __name__ == '__main__':
unittest.main()
......@@ -133,7 +133,7 @@ class TestWithTrainAndEval(unittest.TestCase):
x = fluid.dygraph.to_variable(x_data)
linear_net(x)
_, partial_layer = program_translator.get_program_cache().last()[-1]
_, partial_layer = linear_net.forward.program_cache.last()[-1]
# check default mode is for training
self.assertEqual(partial_layer.program,
partial_layer._train_program)
......
......@@ -133,7 +133,7 @@ class TestPartialProgramRaiseError(unittest.TestCase):
x = fluid.dygraph.to_variable(x_data)
out = net(x)
program_cache = program_translator.get_program_cache()
program_cache = SimpleFcLayer.forward.program_cache
_, (concrete_program, _) = program_cache.last()
params = concrete_program.parameters
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.static import InputSpec
from paddle.fluid.framework import core, convert_np_dtype_to_dtype_
class TestInputSpec(unittest.TestCase):
def test_default(self):
tensor_spec = InputSpec([3, 4])
self.assertEqual(tensor_spec.dtype,
convert_np_dtype_to_dtype_('float32'))
self.assertEqual(tensor_spec.name, None)
def test_from_tensor(self):
x_bool = fluid.layers.fill_constant(shape=[1], dtype='bool', value=True)
bool_spec = InputSpec.from_tensor(x_bool)
self.assertEqual(bool_spec.dtype, x_bool.dtype)
self.assertEqual(bool_spec.shape, x_bool.shape)
self.assertEqual(bool_spec.name, x_bool.name)
bool_spec2 = InputSpec.from_tensor(x_bool, name='bool_spec')
self.assertEqual(bool_spec2.name, bool_spec2.name)
def test_from_numpy(self):
x_numpy = np.ones([10, 12])
x_np_spec = InputSpec.from_numpy(x_numpy)
self.assertEqual(x_np_spec.dtype,
convert_np_dtype_to_dtype_(x_numpy.dtype))
self.assertEqual(x_np_spec.shape, x_numpy.shape)
self.assertEqual(x_np_spec.name, None)
x_numpy2 = np.array([1, 2, 3, 4]).astype('int64')
x_np_spec2 = InputSpec.from_numpy(x_numpy2, name='x_np_int64')
self.assertEqual(x_np_spec2.dtype,
convert_np_dtype_to_dtype_(x_numpy2.dtype))
self.assertEqual(x_np_spec2.shape, x_numpy2.shape)
self.assertEqual(x_np_spec2.name, 'x_np_int64')
def test_shape_with_none(self):
tensor_spec = InputSpec([None, 4, None], dtype='int8', name='x_spec')
self.assertEqual(tensor_spec.dtype, convert_np_dtype_to_dtype_('int8'))
self.assertEqual(tensor_spec.name, 'x_spec')
self.assertEqual(tensor_spec.shape, (-1, 4, -1))
def test_shape_raise_error(self):
# 1. shape should only contain int and None.
with self.assertRaises(ValueError):
tensor_spec = InputSpec(['None', 4, None], dtype='int8')
# 2. shape should be type `list` or `tuple`
with self.assertRaises(TypeError):
tensor_spec = InputSpec(4, dtype='int8')
# 3. len(shape) should be greater than 0.
with self.assertRaises(ValueError):
tensor_spec = InputSpec([], dtype='int8')
def test_batch_and_unbatch(self):
tensor_spec = InputSpec([10])
# insert batch_size
batch_tensor_spec = tensor_spec.batch(16)
self.assertEqual(batch_tensor_spec.shape, (16, 10))
# unbatch
unbatch_spec = batch_tensor_spec.unbatch()
self.assertEqual(unbatch_spec.shape, (10, ))
# 1. `unbatch` requires len(shape) > 1
with self.assertRaises(ValueError):
unbatch_spec.unbatch()
# 2. `batch` requires len(batch_size) == 1
with self.assertRaises(ValueError):
tensor_spec.batch([16, 12])
# 3. `batch` requires type(batch_size) == int
with self.assertRaises(TypeError):
tensor_spec.batch('16')
def test_eq_and_hash(self):
tensor_spec_1 = InputSpec([10, 16], dtype='float32')
tensor_spec_2 = InputSpec([10, 16], dtype='float32')
tensor_spec_3 = InputSpec([10, 16], dtype='float32', name='x')
tensor_spec_4 = InputSpec([16], dtype='float32', name='x')
# override ``__eq__`` according to [shape, dtype, name]
self.assertTrue(tensor_spec_1 == tensor_spec_2)
self.assertTrue(tensor_spec_1 != tensor_spec_3) # different name
self.assertTrue(tensor_spec_3 != tensor_spec_4) # different shape
# override ``__hash__`` according to [shape, dtype]
self.assertTrue(hash(tensor_spec_1) == hash(tensor_spec_2))
self.assertTrue(hash(tensor_spec_1) == hash(tensor_spec_3))
self.assertTrue(hash(tensor_spec_3) != hash(tensor_spec_4))
if __name__ == '__main__':
unittest.main()
......@@ -19,11 +19,11 @@ import pickle
import unittest
import numpy as np
import paddle
from paddle.static import InputSpec
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear
from paddle.fluid.dygraph import declarative, ProgramTranslator
from paddle.fluid.dygraph.io import VARIABLE_FILENAME, EXTRA_VAR_INFO_FILENAME
from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME
BATCH_SIZE = 32
BATCH_NUM = 10
......@@ -156,7 +156,7 @@ class TestJitSaveLoad(unittest.TestCase):
def load_dygraph_state_dict(self, train_layer):
train_layer.eval()
# contruct new model
# construct new model
new_layer = LinearNet(784, 1)
model_dict, _ = fluid.dygraph.load_dygraph(self.model_path)
new_layer.set_dict(model_dict)
......@@ -176,7 +176,7 @@ class TestJitSaveLoad(unittest.TestCase):
model_path=self.model_path,
input_spec=example_inputs)
def test_load_dygraoh_no_path(self):
def test_load_dygraph_no_path(self):
model_path = "model.test_jit_save_load.no_path"
new_layer = LinearNet(784, 1)
with self.assertRaises(ValueError):
......@@ -202,6 +202,92 @@ class TestJitSaveLoad(unittest.TestCase):
model_dict, _ = fluid.dygraph.load_dygraph(model_path)
class LinearNetMultiInput(fluid.dygraph.Layer):
def __init__(self, in_size, out_size):
super(LinearNetMultiInput, self).__init__()
self._linear1 = Linear(in_size, out_size)
# self._linear2 = Linear(in_size, out_size)
@declarative(input_spec=[
InputSpec(
[None, 8], dtype='float32'), InputSpec(
[None, 8], dtype='float32')
])
def forward(self, x, y):
x_out = self._linear1(x)
y_out = self._linear1(y)
loss = fluid.layers.mean(x_out + y_out)
return x_out, y_out, loss
class TestSaveLoadWithInputSpec(unittest.TestCase):
def setUp(self):
# enable dygraph mode
fluid.enable_dygraph()
def test_with_input_spec(self):
net = LinearNetReturnLoss(8, 8)
# set x.shape = [None, 8]
net.forward = declarative(
net.forward, input_spec=[InputSpec(
[None, 8], name='x')])
model_path = "model.input_spec.output_spec"
configs = fluid.dygraph.jit.SaveLoadConfig()
# check inputs and outputs
self.assertTrue(len(net.forward.inputs) == 1)
input_x = net.forward.inputs[0]
self.assertTrue(input_x.shape == (-1, 8))
self.assertTrue(input_x.name == 'x')
# 1. prune loss
configs.output_spec = net.forward.outputs[:1]
fluid.dygraph.jit.save(net, model_path, configs=configs)
# 2. load to infer
infer_layer = fluid.dygraph.jit.load(model_path, configs=configs)
x = fluid.dygraph.to_variable(
np.random.random((4, 8)).astype('float32'))
pred = infer_layer(x)
def test_multi_in_out(self):
net = LinearNetMultiInput(8, 8)
model_path = "model.multi_inout.output_spec1"
configs = fluid.dygraph.jit.SaveLoadConfig()
# 1. check inputs and outputs
self.assertTrue(len(net.forward.inputs) == 2)
input_x = net.forward.inputs[0]
input_y = net.forward.inputs[1]
self.assertTrue(input_x.shape == (-1, 8))
self.assertTrue(input_y.shape == (-1, 8))
# 2. prune loss
configs.output_spec = net.forward.outputs[:2]
fluid.dygraph.jit.save(net, model_path, configs=configs)
# 3. load to infer
infer_layer = fluid.dygraph.jit.load(model_path, configs=configs)
x = fluid.dygraph.to_variable(
np.random.random((4, 8)).astype('float32'))
y = fluid.dygraph.to_variable(
np.random.random((4, 8)).astype('float32'))
# 4. predict
pred_x, pred_y = infer_layer(x, y)
# 1. prune y and loss
model_path = "model.multi_inout.output_spec2"
configs.output_spec = net.forward.outputs[:1]
fluid.dygraph.jit.save(net, model_path, [input_x], configs)
# 2. load again
infer_layer2 = fluid.dygraph.jit.load(model_path, configs=configs)
# 3. predict
pred_xx = infer_layer2(x)
# 4. assert pred_x == pred_xx
self.assertTrue(np.allclose(pred_x.numpy(), pred_xx.numpy()))
class TestJitSaveLoadConfig(unittest.TestCase):
def setUp(self):
# enable dygraph mode
......
......@@ -1607,10 +1607,7 @@ class Model(object):
% type(layer))
# 2. get program of declarative Layer.forward
prog_cache = prog_translator.get_program_cache()
# make dummy args & kwargs, to get excepted FunctionSpec
layer_func = FunctionSpec(type(layer).forward, [layer], {})
concrete_program, _ = prog_cache.get_program(layer_func)
concrete_program = layer.forward.concrete_program
# NOTE: we maintain the mapping of variable name to
# structured name, the buffer variable (non-persistable)
......@@ -1742,12 +1739,13 @@ class Model(object):
out_specs = []
if specs is None:
# If not specific specs of `Input`, using argument names of `forward` function
# to generate `Input`.
# Note(Aurelius84): If not specific specs of `Input`, using argument names of `forward` function
# to generate `Input`. But how can we know the actual shape of each input tensor?
if is_input:
out_specs = [
Input(name=n) for n in extract_args(self.network.forward)
if n != 'self'
Input(
name=n, shape=[None])
for n in extract_args(self.network.forward) if n != 'self'
]
else:
out_specs = to_list(specs)
......
......@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import six
from paddle.fluid import core
import paddle
from paddle.fluid import core, Variable
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.data_feeder import check_dtype, check_type
from paddle.fluid.data_feeder import check_type
from paddle.fluid.framework import convert_np_dtype_to_dtype_
__all__ = ['data', 'InputSpec']
......@@ -41,7 +41,7 @@ def data(name, shape, dtype=None, lod_level=0):
size. For example, it is useful to set changeable batch size as "None" or -1.
dtype (np.dtype|str, optional): The type of the data. Supported
dtype: bool, float16, float32, float64, int8, int16, int32, int64,
uint8. Default: None. When `dtype` is not set, the dtype will get
uint8. Default: None. When `dtype` is not set, the dtype will get
from the global dtype by `paddle.get_default_dtype()`.
lod_level (int, optional): The LoD level of the LoDTensor. Usually users
don't have to set this value. For more details about when and how to
......@@ -54,13 +54,12 @@ def data(name, shape, dtype=None, lod_level=0):
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
import paddle
# Creates a variable with fixed size [3, 2, 1]
# User can only feed data of the same shape to x
# the dtype is not set, so it will set "float32" by
# paddle.get_default_dtype(). You can use paddle.get_default_dtype() to
# paddle.get_default_dtype(). You can use paddle.get_default_dtype() to
# change the global dtype
x = paddle.static.data(name='x', shape=[3, 2, 1])
......@@ -75,8 +74,8 @@ def data(name, shape, dtype=None, lod_level=0):
# and fetch z, like implementing "1 + 1 = 2" in PaddlePaddle
feed_data = np.ones(shape=[3, 2, 1], dtype=np.float32)
exe = fluid.Executor(fluid.CPUPlace())
out = exe.run(fluid.default_main_program(),
exe = paddle.static.Executor(paddle.framework.CPUPlace())
out = exe.run(paddle.static.default_main_program(),
feed={
'x': feed_data,
'y': feed_data
......@@ -120,11 +119,13 @@ def data(name, shape, dtype=None, lod_level=0):
class InputSpec(object):
"""
Define input specification of the model.
InputSpec describes the signature information of the model input, such as ``shape`` , ``dtype`` , ``name`` .
This interface is often used to specify input tensor information of models in high-level API.
It's also used to specify the tensor information for each input parameter of the forward function
decorated by `@paddle.jit.to_static`.
Args:
name (str): The name/alias of the variable, see :ref:`api_guide_Name`
for more details.
shape (tuple(integers)|list[integers]): List|Tuple of integers
declaring the shape. You can set "None" or -1 at a dimension
to indicate the dimension can be of any size. For example,
......@@ -132,18 +133,28 @@ class InputSpec(object):
dtype (np.dtype|str, optional): The type of the data. Supported
dtype: bool, float16, float32, float64, int8, int16, int32, int64,
uint8. Default: float32.
name (str): The name/alias of the variable, see :ref:`api_guide_Name`
for more details.
Examples:
.. code-block:: python
from paddle.static import InputSpec
from paddle.static import InputSpec
input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
print(input) # InputSpec(shape=(-1, 784), dtype=VarType.FP32, name=x)
print(label) # InputSpec(shape=(-1, 1), dtype=VarType.INT64, name=label)
"""
def __init__(self, shape=None, dtype='float32', name=None):
self.shape = shape
def __init__(self, shape, dtype='float32', name=None):
# replace `None` in shape with -1
self.shape = self._verify(shape)
# convert dtype into united represention
if dtype is not None:
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
self.dtype = dtype
self.name = name
......@@ -153,3 +164,167 @@ class InputSpec(object):
def __repr__(self):
return '{}(shape={}, dtype={}, name={})'.format(
type(self).__name__, self.shape, self.dtype, self.name)
@classmethod
def from_tensor(cls, tensor, name=None):
"""
Generates a InputSpec based on the description of input tensor.
Args:
tensor(Tensor): the source tensor to generate a InputSpec instance
Returns:
A InputSpec instance generated from Tensor.
Examples:
.. code-block:: python
import numpy as np
import paddle
from paddle.static import InputSpec
paddle.disable_static()
x = paddle.to_tensor(np.ones([2, 2], np.float32))
x_spec = InputSpec.from_tensor(x, name='x')
print(x_spec) # InputSpec(shape=(2, 2), dtype=VarType.FP32, name=x)
"""
if isinstance(tensor, (Variable, core.VarBase)):
return cls(tensor.shape, tensor.dtype, name or tensor.name)
else:
raise ValueError(
"Input `tensor` should be a Tensor, but received {}.".format(
type(tensor).__name__))
@classmethod
def from_numpy(cls, ndarray, name=None):
"""
Generates a InputSpec based on the description of input np.ndarray.
Args:
tensor(Tensor): the source numpy ndarray to generate a InputSpec instance
Returns:
A InputSpec instance generated from Tensor.
Examples:
.. code-block:: python
import numpy as np
from paddle.static import InputSpec
x = np.ones([2, 2], np.float32)
x_spec = InputSpec.from_numpy(x, name='x')
print(x_spec) # InputSpec(shape=(2, 2), dtype=VarType.FP32, name=x)
"""
return cls(ndarray.shape, ndarray.dtype, name)
def batch(self, batch_size):
"""
Inserts `batch_size` in front of the `shape`.
Args:
batch_size(int): the inserted integer value of batch size.
Returns:
The original InputSpec instance by inserting `batch_size` in front of `shape`.
Examples:
.. code-block:: python
from paddle.static import InputSpec
x_spec = InputSpec(shape=[64], dtype='float32', name='x')
x_spec.batch(4)
print(x_spec) # InputSpec(shape=(4, 64), dtype=VarType.FP32, name=x)
"""
if isinstance(batch_size, (list, tuple)):
if len(batch_size) != 1:
raise ValueError(
"Length of batch_size: {} shall be 1, but received {}.".
format(batch_size, len(batch_size)))
batch_size = batch_size[1]
elif not isinstance(batch_size, six.integer_types):
raise TypeError("type(batch_size) shall be `int`, but received {}.".
format(type(batch_size).__name__))
new_shape = [batch_size] + list(self.shape)
self.shape = tuple(new_shape)
return self
def unbatch(self):
"""
Removes the first element of `shape`.
Returns:
The original InputSpec instance by removing the first element of `shape` .
Examples:
.. code-block:: python
from paddle.static import InputSpec
x_spec = InputSpec(shape=[4, 64], dtype='float32', name='x')
x_spec.unbatch()
print(x_spec) # InputSpec(shape=(64,), dtype=VarType.FP32, name=x)
"""
if len(self.shape) == 0:
raise ValueError(
"Not support to unbatch a InputSpec when len(shape) == 0.")
self.shape = self._verify(self.shape[1:])
return self
def _verify(self, shape):
"""
Verifies the input shape and modifies `None` into `-1`.
"""
if not isinstance(shape, (list, tuple)):
raise TypeError(
"Type of `shape` in InputSpec should be one of (tuple, list), but received {}.".
format(type(shape).__name__))
if len(shape) == 0:
raise ValueError(
"`shape` in InputSpec should contain at least 1 element, but received {}.".
format(shape))
for i, ele in enumerate(shape):
if ele is not None:
if not isinstance(ele, six.integer_types):
raise ValueError(
"shape[{}] should be an `int`, but received `{}`:{}.".
format(i, type(ele).__name__, ele))
if ele is None or ele < -1:
shape[i] = -1
return tuple(shape)
def __hash__(self):
# Note(Aurelius84): `name` is not considered as a field to compute hashkey.
# Because it's no need to generate a new program in following cases while using
# @paddle.jit.to_static.
#
# Case 1:
# foo(x_var)
# foo(y_var)
# x_var and y_var hold same shape and dtype, they should share a same program.
#
#
# Case 2:
# foo(x_var)
# foo(x_np) # x_np is a numpy.ndarray.
# x_var and x_np hold same shape and dtype, they should also share a same program.
return hash((tuple(self.shape), self.dtype))
def __eq__(self, other):
slots = ['shape', 'dtype', 'name']
return (type(self) is type(other) and all(
getattr(self, attr) == getattr(other, attr) for attr in slots))
def __ne__(self, other):
return not self == other
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册