提交 4135aeeb 编写于 作者: F François Chollet 提交者: GitHub

Revert "Avoid DeprecationWarning from inspect.getargspec (#6817)" (#7018)

This reverts commit ced84c4b.
上级 ced84c4b
from __future__ import absolute_import
from . import utils
from . import activations
from . import applications
from . import backend
......@@ -8,6 +7,7 @@ from . import datasets
from . import engine
from . import layers
from . import preprocessing
from . import utils
from . import wrappers
from . import callbacks
from . import constraints
......
......@@ -7,14 +7,13 @@ from tensorflow.python.ops import ctc_ops as ctc
from tensorflow.python.ops import variables as tf_variables
from collections import defaultdict
import inspect
import numpy as np
import os
from .common import floatx
from .common import _EPSILON
from .common import image_data_format
from ..utils.generic_utils import has_arg
# Legacy functions
from .common import set_image_dim_ordering
......@@ -2286,7 +2285,8 @@ def function(inputs, outputs, updates=None, **kwargs):
"""
if kwargs:
for key in kwargs:
if not (has_arg(tf.Session.run, key, True) or has_arg(Function.__init__, key, True)):
if (key not in inspect.getargspec(tf.Session.run)[0] and
key not in inspect.getargspec(Function.__init__)[0]):
msg = 'Invalid argument "%s" passed to K.function with Tensorflow backend' % key
raise ValueError(msg)
return Function(inputs, outputs, updates=updates, **kwargs)
......
......@@ -14,10 +14,9 @@ try:
from theano.tensor.nnet.nnet import softsign as T_softsign
except ImportError:
from theano.sandbox.softsign import softsign as T_softsign
import inspect
import numpy as np
from .common import _FLOATX, floatx, _EPSILON, image_data_format
from ..utils.generic_utils import has_arg
# Legacy functions
from .common import set_image_dim_ordering, image_dim_ordering
......@@ -1195,8 +1194,9 @@ class Function(object):
def function(inputs, outputs, updates=[], **kwargs):
if len(kwargs) > 0:
function_args = inspect.getargspec(theano.function)[0]
for key in kwargs.keys():
if not has_arg(theano.function, key, True):
if key not in function_args:
msg = 'Invalid argument "%s" passed to K.function with Theano backend' % key
raise ValueError(msg)
return Function(inputs, outputs, updates=updates, **kwargs)
......
......@@ -10,13 +10,13 @@ import warnings
import copy
import os
import re
import inspect
from six.moves import zip
from .. import backend as K
from .. import initializers
from ..utils.io_utils import ask_to_proceed_with_overwrite
from ..utils.layer_utils import print_summary as print_layer_summary
from ..utils.generic_utils import has_arg
from ..utils import conv_utils
from ..legacy import interfaces
......@@ -584,7 +584,7 @@ class Layer(object):
user_kwargs = copy.copy(kwargs)
if not _is_all_none(previous_mask):
# The previous layer generated a mask.
if has_arg(self.call, 'mask'):
if 'mask' in inspect.getargspec(self.call).args:
if 'mask' not in kwargs:
# If mask is explicitly passed to __call__,
# we should override the default mask.
......@@ -2206,7 +2206,7 @@ class Container(Layer):
kwargs = {}
if len(computed_data) == 1:
computed_tensor, computed_mask = computed_data[0]
if has_arg(layer.call, 'mask'):
if 'mask' in inspect.getargspec(layer.call).args:
if 'mask' not in kwargs:
kwargs['mask'] = computed_mask
output_tensors = _to_list(layer.call(computed_tensor, **kwargs))
......@@ -2217,7 +2217,7 @@ class Container(Layer):
else:
computed_tensors = [x[0] for x in computed_data]
computed_masks = [x[1] for x in computed_data]
if has_arg(layer.call, 'mask'):
if 'mask' in inspect.getargspec(layer.call).args:
if 'mask' not in kwargs:
kwargs['mask'] = computed_masks
output_tensors = _to_list(layer.call(computed_tensors, **kwargs))
......
......@@ -5,6 +5,7 @@ from __future__ import division
import numpy as np
import copy
import inspect
import types as python_types
import warnings
......@@ -18,7 +19,6 @@ from ..engine import Layer
from ..utils.generic_utils import func_dump
from ..utils.generic_utils import func_load
from ..utils.generic_utils import deserialize_keras_object
from ..utils.generic_utils import has_arg
from ..legacy import interfaces
......@@ -642,7 +642,8 @@ class Lambda(Layer):
def call(self, inputs, mask=None):
arguments = self.arguments
if has_arg(self.function, 'mask'):
arg_spec = inspect.getargspec(self.function)
if 'mask' in arg_spec.args:
arguments['mask'] = mask
return self.function(inputs, **arguments)
......
......@@ -2,9 +2,9 @@
from __future__ import absolute_import
import copy
import inspect
from ..engine import Layer
from ..engine import InputSpec
from ..utils.generic_utils import has_arg
from .. import backend as K
......@@ -272,9 +272,10 @@ class Bidirectional(Wrapper):
def call(self, inputs, training=None, mask=None):
kwargs = {}
if has_arg(self.layer.call, 'training'):
func_args = inspect.getargspec(self.layer.call).args
if 'training' in func_args:
kwargs['training'] = training
if has_arg(self.layer.call, 'mask'):
if 'mask' in func_args:
kwargs['mask'] = mask
y = self.forward_layer.call(inputs, **kwargs)
......
import inspect
import types as python_types
import warnings
from ..engine.topology import Layer, InputSpec
from .. import backend as K
from ..utils.generic_utils import func_dump, func_load, has_arg
from ..utils.generic_utils import func_dump, func_load
from .. import regularizers
from .. import constraints
from .. import activations
......@@ -196,7 +197,8 @@ class Merge(Layer):
# Case: "mode" is a lambda or function.
if callable(self.mode):
arguments = self.arguments
if has_arg(self.mode, 'mask'):
arg_spec = inspect.getargspec(self.mode)
if 'mask' in arg_spec.args:
arguments['mask'] = mask
return self.mode(inputs, **arguments)
......
from __future__ import absolute_import
from . import np_utils
from . import generic_utils
from . import conv_utils
from . import data_utils
from . import generic_utils
from . import io_utils
from . import conv_utils
# Globally-importable utils.
from .io_utils import HDF5Matrix
......
......@@ -132,8 +132,9 @@ def deserialize_keras_object(identifier, module_objects=None,
raise ValueError('Unknown ' + printable_module_name +
': ' + class_name)
if hasattr(cls, 'from_config'):
arg_spec = inspect.getargspec(cls.from_config)
custom_objects = custom_objects or {}
if has_arg(cls.from_config, 'custom_objects'):
if 'custom_objects' in arg_spec.args:
return cls.from_config(config['config'],
custom_objects=dict(list(_GLOBAL_CUSTOM_OBJECTS.items()) +
list(custom_objects.items())))
......@@ -206,48 +207,6 @@ def func_load(code, defaults=None, closure=None, globs=None):
closure=closure)
def has_arg(fn, name, accept_all=False):
"""Checks if a callable accepts a given keyword argument.
For Python 2, checks if there is an argument with the given name.
For Python 3, checks if there is an argument with the given name, and
also whether this argument can be called with a keyword (i.e. if it is
not a positional-only argument).
# Arguments
fn: Callable to inspect.
name: Check if `fn` can be called with `name` as a keyword argument.
accept_all: What to return if there is no parameter called `name`
but the function accepts a `**kwargs` argument.
# Returns
bool, whether `fn` accepts a `name` keyword argument.
"""
if sys.version_info < (3,):
arg_spec = inspect.getargspec(fn)
if accept_all and arg_spec.keywords is not None:
return True
return (name in arg_spec.args)
elif sys.version_info < (3, 3):
arg_spec = inspect.getfullargspec(fn)
if accept_all and arg_spec.varkw is not None:
return True
return (name in arg_spec.args or
name in arg_spec.kwonlyargs)
else:
signature = inspect.signature(fn)
parameter = signature.parameters.get(name)
if parameter is None:
if accept_all:
for param in signature.parameters.values():
if param.kind == inspect.Parameter.VAR_KEYWORD:
return True
return False
return (parameter.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY))
class Progbar(object):
"""Displays a progress bar.
......
"""Utilities related to Keras unit tests."""
import numpy as np
from numpy.testing import assert_allclose
import inspect
import six
from .generic_utils import has_arg
from ..engine import Model, Input
from ..models import Sequential
from ..models import model_from_json
......@@ -71,9 +71,7 @@ def layer_test(layer_cls, kwargs={}, input_shape=None, input_dtype=None,
layer.set_weights(weights)
# test and instantiation from weights
# Checking for empty weights array to avoid a problem where some
# legacy layers return bad values from get_weights()
if has_arg(layer_cls.__init__, 'weights') and len(weights):
if 'weights' in inspect.getargspec(layer_cls.__init__):
kwargs['weights'] = weights
layer = layer_cls(**kwargs)
......
from __future__ import absolute_import
import copy
import inspect
import types
import numpy as np
from ..utils.np_utils import to_categorical
from ..utils.generic_utils import has_arg
from ..models import Sequential
......@@ -75,11 +75,13 @@ class BaseWrapper(object):
else:
legal_params_fns.append(self.build_fn)
legal_params = []
for fn in legal_params_fns:
legal_params += inspect.getargspec(fn)[0]
legal_params = set(legal_params)
for params_name in params:
for fn in legal_params_fns:
if has_arg(fn, params_name):
break
else:
if params_name not in legal_params:
if params_name != 'nb_epoch':
raise ValueError(
'{} is not a legal parameter'.format(params_name))
......@@ -161,8 +163,9 @@ class BaseWrapper(object):
"""
override = override or {}
res = {}
fn_args = inspect.getargspec(fn)[0]
for name, value in self.sk_params.items():
if has_arg(fn, name):
if name in fn_args:
res.update({name: value})
res.update(override)
return res
......
import sys
import pytest
from keras.utils.generic_utils import custom_object_scope, has_arg
from keras.utils.generic_utils import custom_object_scope
from keras import activations
from keras import regularizers
......@@ -21,46 +20,5 @@ def test_custom_objects_scope():
assert cl.__class__ == CustomClass
@pytest.mark.parametrize('fn, name, accept_all, expected', [
('f(x)', 'x', False, True),
('f(x)', 'y', False, False),
('f(x)', 'y', True, False),
('f(x, y)', 'y', False, True),
('f(x, y=1)', 'y', False, True),
('f(x, **kwargs)', 'x', False, True),
('f(x, **kwargs)', 'y', False, False),
('f(x, **kwargs)', 'y', True, True),
('f(x, y=1, **kwargs)', 'y', False, True),
# Keyword-only arguments (Python 3 only)
('f(x, *args, y=1)', 'y', False, True),
('f(x, *args, y=1)', 'z', True, False),
('f(x, *, y=1)', 'x', False, True),
('f(x, *, y=1)', 'y', False, True),
# lambda
(lambda x: x, 'x', False, True),
(lambda x: x, 'y', False, False),
(lambda x: x, 'y', True, False),
])
def test_has_arg(fn, name, accept_all, expected):
if isinstance(fn, str):
context = dict()
try:
exec('def {}: pass'.format(fn), context)
except SyntaxError:
if sys.version_info >= (3,):
raise
pytest.skip('Function is not compatible with Python 2')
context.pop('__builtins__', None) # Sometimes exec adds builtins to the context
fn, = context.values()
assert has_arg(fn, name, accept_all) is expected
@pytest.mark.xfail(sys.version_info < (3, 3),
reason='inspect API does not reveal positional-only arguments')
def test_has_arg_positional_only():
assert has_arg(pow, 'x') is False
if __name__ == '__main__':
pytest.main([__file__])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册