diff --git a/keras/__init__.py b/keras/__init__.py index 1c8e66f99dd6ed9292a553e91994ee39845a9198..345f530137094185253f2b9043eef9b101ed36ca 100644 --- a/keras/__init__.py +++ b/keras/__init__.py @@ -1,6 +1,5 @@ from __future__ import absolute_import -from . import utils from . import activations from . import applications from . import backend @@ -8,6 +7,7 @@ from . import datasets from . import engine from . import layers from . import preprocessing +from . import utils from . import wrappers from . import callbacks from . import constraints diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index c8b8a2042e7761afa1fc6a2b97034b8273502dfe..ec6b637207d0fc55031812598b708a2ba66acaf2 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -7,14 +7,13 @@ from tensorflow.python.ops import ctc_ops as ctc from tensorflow.python.ops import variables as tf_variables from collections import defaultdict - +import inspect import numpy as np import os from .common import floatx from .common import _EPSILON from .common import image_data_format -from ..utils.generic_utils import has_arg # Legacy functions from .common import set_image_dim_ordering @@ -2286,7 +2285,8 @@ def function(inputs, outputs, updates=None, **kwargs): """ if kwargs: for key in kwargs: - if not (has_arg(tf.Session.run, key, True) or has_arg(Function.__init__, key, True)): + if (key not in inspect.getargspec(tf.Session.run)[0] and + key not in inspect.getargspec(Function.__init__)[0]): msg = 'Invalid argument "%s" passed to K.function with Tensorflow backend' % key raise ValueError(msg) return Function(inputs, outputs, updates=updates, **kwargs) diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index da88e3ad0917d39b8b1ada0413912f78bfedb2cc..17b8d85c4035eddbf0888f6259fae79ebd78f776 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -14,10 +14,9 @@ try: from theano.tensor.nnet.nnet import softsign as T_softsign except ImportError: from theano.sandbox.softsign import softsign as T_softsign - +import inspect import numpy as np from .common import _FLOATX, floatx, _EPSILON, image_data_format -from ..utils.generic_utils import has_arg # Legacy functions from .common import set_image_dim_ordering, image_dim_ordering @@ -1195,8 +1194,9 @@ class Function(object): def function(inputs, outputs, updates=[], **kwargs): if len(kwargs) > 0: + function_args = inspect.getargspec(theano.function)[0] for key in kwargs.keys(): - if not has_arg(theano.function, key, True): + if key not in function_args: msg = 'Invalid argument "%s" passed to K.function with Theano backend' % key raise ValueError(msg) return Function(inputs, outputs, updates=updates, **kwargs) diff --git a/keras/engine/topology.py b/keras/engine/topology.py index f7084aad111318749d67c1f481842148dd8e8752..c7450351deca0f89fbe58d190852b0f56e8dbec1 100644 --- a/keras/engine/topology.py +++ b/keras/engine/topology.py @@ -10,13 +10,13 @@ import warnings import copy import os import re +import inspect from six.moves import zip from .. import backend as K from .. import initializers from ..utils.io_utils import ask_to_proceed_with_overwrite from ..utils.layer_utils import print_summary as print_layer_summary -from ..utils.generic_utils import has_arg from ..utils import conv_utils from ..legacy import interfaces @@ -584,7 +584,7 @@ class Layer(object): user_kwargs = copy.copy(kwargs) if not _is_all_none(previous_mask): # The previous layer generated a mask. - if has_arg(self.call, 'mask'): + if 'mask' in inspect.getargspec(self.call).args: if 'mask' not in kwargs: # If mask is explicitly passed to __call__, # we should override the default mask. @@ -2206,7 +2206,7 @@ class Container(Layer): kwargs = {} if len(computed_data) == 1: computed_tensor, computed_mask = computed_data[0] - if has_arg(layer.call, 'mask'): + if 'mask' in inspect.getargspec(layer.call).args: if 'mask' not in kwargs: kwargs['mask'] = computed_mask output_tensors = _to_list(layer.call(computed_tensor, **kwargs)) @@ -2217,7 +2217,7 @@ class Container(Layer): else: computed_tensors = [x[0] for x in computed_data] computed_masks = [x[1] for x in computed_data] - if has_arg(layer.call, 'mask'): + if 'mask' in inspect.getargspec(layer.call).args: if 'mask' not in kwargs: kwargs['mask'] = computed_masks output_tensors = _to_list(layer.call(computed_tensors, **kwargs)) diff --git a/keras/layers/core.py b/keras/layers/core.py index 8923273421497e1539bdb5b7ec5a0152aa9c2f91..23f100b78f8afc10a47e857292e17cc92b847139 100644 --- a/keras/layers/core.py +++ b/keras/layers/core.py @@ -5,6 +5,7 @@ from __future__ import division import numpy as np import copy +import inspect import types as python_types import warnings @@ -18,7 +19,6 @@ from ..engine import Layer from ..utils.generic_utils import func_dump from ..utils.generic_utils import func_load from ..utils.generic_utils import deserialize_keras_object -from ..utils.generic_utils import has_arg from ..legacy import interfaces @@ -642,7 +642,8 @@ class Lambda(Layer): def call(self, inputs, mask=None): arguments = self.arguments - if has_arg(self.function, 'mask'): + arg_spec = inspect.getargspec(self.function) + if 'mask' in arg_spec.args: arguments['mask'] = mask return self.function(inputs, **arguments) diff --git a/keras/layers/wrappers.py b/keras/layers/wrappers.py index 8042927ec66107ad964747f15c3fb5247de53a07..d30ec12da96d771ea5bdadba6221b7e898cd218c 100644 --- a/keras/layers/wrappers.py +++ b/keras/layers/wrappers.py @@ -2,9 +2,9 @@ from __future__ import absolute_import import copy +import inspect from ..engine import Layer from ..engine import InputSpec -from ..utils.generic_utils import has_arg from .. import backend as K @@ -272,9 +272,10 @@ class Bidirectional(Wrapper): def call(self, inputs, training=None, mask=None): kwargs = {} - if has_arg(self.layer.call, 'training'): + func_args = inspect.getargspec(self.layer.call).args + if 'training' in func_args: kwargs['training'] = training - if has_arg(self.layer.call, 'mask'): + if 'mask' in func_args: kwargs['mask'] = mask y = self.forward_layer.call(inputs, **kwargs) diff --git a/keras/legacy/layers.py b/keras/legacy/layers.py index 1e53fc53b54551a19fd4f6be6459b62dc0166987..bbbb8f221759b893b41c0de508e8e965928f47eb 100644 --- a/keras/legacy/layers.py +++ b/keras/legacy/layers.py @@ -1,9 +1,10 @@ +import inspect import types as python_types import warnings from ..engine.topology import Layer, InputSpec from .. import backend as K -from ..utils.generic_utils import func_dump, func_load, has_arg +from ..utils.generic_utils import func_dump, func_load from .. import regularizers from .. import constraints from .. import activations @@ -196,7 +197,8 @@ class Merge(Layer): # Case: "mode" is a lambda or function. if callable(self.mode): arguments = self.arguments - if has_arg(self.mode, 'mask'): + arg_spec = inspect.getargspec(self.mode) + if 'mask' in arg_spec.args: arguments['mask'] = mask return self.mode(inputs, **arguments) diff --git a/keras/utils/__init__.py b/keras/utils/__init__.py index 245159fa5fad436bc072429c91ba32d9d79e1579..793dea19a2f08a858840cedba90432694d937ca2 100644 --- a/keras/utils/__init__.py +++ b/keras/utils/__init__.py @@ -1,9 +1,9 @@ from __future__ import absolute_import from . import np_utils -from . import generic_utils +from . import conv_utils from . import data_utils +from . import generic_utils from . import io_utils -from . import conv_utils # Globally-importable utils. from .io_utils import HDF5Matrix diff --git a/keras/utils/generic_utils.py b/keras/utils/generic_utils.py index 30a65f052ca4fc5cf40fd9b2410745d98b57ffe7..76477d5ac3864426ee6a8bc755618efdfb77e24b 100644 --- a/keras/utils/generic_utils.py +++ b/keras/utils/generic_utils.py @@ -132,8 +132,9 @@ def deserialize_keras_object(identifier, module_objects=None, raise ValueError('Unknown ' + printable_module_name + ': ' + class_name) if hasattr(cls, 'from_config'): + arg_spec = inspect.getargspec(cls.from_config) custom_objects = custom_objects or {} - if has_arg(cls.from_config, 'custom_objects'): + if 'custom_objects' in arg_spec.args: return cls.from_config(config['config'], custom_objects=dict(list(_GLOBAL_CUSTOM_OBJECTS.items()) + list(custom_objects.items()))) @@ -206,48 +207,6 @@ def func_load(code, defaults=None, closure=None, globs=None): closure=closure) -def has_arg(fn, name, accept_all=False): - """Checks if a callable accepts a given keyword argument. - - For Python 2, checks if there is an argument with the given name. - - For Python 3, checks if there is an argument with the given name, and - also whether this argument can be called with a keyword (i.e. if it is - not a positional-only argument). - - # Arguments - fn: Callable to inspect. - name: Check if `fn` can be called with `name` as a keyword argument. - accept_all: What to return if there is no parameter called `name` - but the function accepts a `**kwargs` argument. - - # Returns - bool, whether `fn` accepts a `name` keyword argument. - """ - if sys.version_info < (3,): - arg_spec = inspect.getargspec(fn) - if accept_all and arg_spec.keywords is not None: - return True - return (name in arg_spec.args) - elif sys.version_info < (3, 3): - arg_spec = inspect.getfullargspec(fn) - if accept_all and arg_spec.varkw is not None: - return True - return (name in arg_spec.args or - name in arg_spec.kwonlyargs) - else: - signature = inspect.signature(fn) - parameter = signature.parameters.get(name) - if parameter is None: - if accept_all: - for param in signature.parameters.values(): - if param.kind == inspect.Parameter.VAR_KEYWORD: - return True - return False - return (parameter.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY)) - - class Progbar(object): """Displays a progress bar. diff --git a/keras/utils/test_utils.py b/keras/utils/test_utils.py index 9c59891d8e7d564efd0431995d609c84c4a91ad5..014fd9c4b9576841615333c42aaeb9b6b574b70a 100644 --- a/keras/utils/test_utils.py +++ b/keras/utils/test_utils.py @@ -1,9 +1,9 @@ """Utilities related to Keras unit tests.""" import numpy as np from numpy.testing import assert_allclose +import inspect import six -from .generic_utils import has_arg from ..engine import Model, Input from ..models import Sequential from ..models import model_from_json @@ -71,9 +71,7 @@ def layer_test(layer_cls, kwargs={}, input_shape=None, input_dtype=None, layer.set_weights(weights) # test and instantiation from weights - # Checking for empty weights array to avoid a problem where some - # legacy layers return bad values from get_weights() - if has_arg(layer_cls.__init__, 'weights') and len(weights): + if 'weights' in inspect.getargspec(layer_cls.__init__): kwargs['weights'] = weights layer = layer_cls(**kwargs) diff --git a/keras/wrappers/scikit_learn.py b/keras/wrappers/scikit_learn.py index 7a79eaa5ce9815c6af78cf31daa9c5620c706140..7aa0036d4b9e6f000fa24b3953be98816ddd4f3b 100644 --- a/keras/wrappers/scikit_learn.py +++ b/keras/wrappers/scikit_learn.py @@ -1,12 +1,12 @@ from __future__ import absolute_import import copy +import inspect import types import numpy as np from ..utils.np_utils import to_categorical -from ..utils.generic_utils import has_arg from ..models import Sequential @@ -75,11 +75,13 @@ class BaseWrapper(object): else: legal_params_fns.append(self.build_fn) + legal_params = [] + for fn in legal_params_fns: + legal_params += inspect.getargspec(fn)[0] + legal_params = set(legal_params) + for params_name in params: - for fn in legal_params_fns: - if has_arg(fn, params_name): - break - else: + if params_name not in legal_params: if params_name != 'nb_epoch': raise ValueError( '{} is not a legal parameter'.format(params_name)) @@ -161,8 +163,9 @@ class BaseWrapper(object): """ override = override or {} res = {} + fn_args = inspect.getargspec(fn)[0] for name, value in self.sk_params.items(): - if has_arg(fn, name): + if name in fn_args: res.update({name: value}) res.update(override) return res diff --git a/tests/keras/utils/generic_utils_test.py b/tests/keras/utils/generic_utils_test.py index f29f89cbf65a01d6a423e7fd00f7fdf4f4f28e6a..2bc71c48aec531b176a3f20444c261201e0dcca5 100644 --- a/tests/keras/utils/generic_utils_test.py +++ b/tests/keras/utils/generic_utils_test.py @@ -1,6 +1,5 @@ -import sys import pytest -from keras.utils.generic_utils import custom_object_scope, has_arg +from keras.utils.generic_utils import custom_object_scope from keras import activations from keras import regularizers @@ -21,46 +20,5 @@ def test_custom_objects_scope(): assert cl.__class__ == CustomClass -@pytest.mark.parametrize('fn, name, accept_all, expected', [ - ('f(x)', 'x', False, True), - ('f(x)', 'y', False, False), - ('f(x)', 'y', True, False), - ('f(x, y)', 'y', False, True), - ('f(x, y=1)', 'y', False, True), - ('f(x, **kwargs)', 'x', False, True), - ('f(x, **kwargs)', 'y', False, False), - ('f(x, **kwargs)', 'y', True, True), - ('f(x, y=1, **kwargs)', 'y', False, True), - # Keyword-only arguments (Python 3 only) - ('f(x, *args, y=1)', 'y', False, True), - ('f(x, *args, y=1)', 'z', True, False), - ('f(x, *, y=1)', 'x', False, True), - ('f(x, *, y=1)', 'y', False, True), - # lambda - (lambda x: x, 'x', False, True), - (lambda x: x, 'y', False, False), - (lambda x: x, 'y', True, False), -]) -def test_has_arg(fn, name, accept_all, expected): - if isinstance(fn, str): - context = dict() - try: - exec('def {}: pass'.format(fn), context) - except SyntaxError: - if sys.version_info >= (3,): - raise - pytest.skip('Function is not compatible with Python 2') - context.pop('__builtins__', None) # Sometimes exec adds builtins to the context - fn, = context.values() - - assert has_arg(fn, name, accept_all) is expected - - -@pytest.mark.xfail(sys.version_info < (3, 3), - reason='inspect API does not reveal positional-only arguments') -def test_has_arg_positional_only(): - assert has_arg(pow, 'x') is False - - if __name__ == '__main__': pytest.main([__file__])