Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_40195168达庆意
keras
提交
4135aeeb
K
keras
项目概览
weixin_40195168达庆意
/
keras
与 Fork 源项目一致
从无法访问的项目Fork
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
K
keras
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
4135aeeb
编写于
6月 16, 2017
作者:
F
François Chollet
提交者:
GitHub
6月 16, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Revert "Avoid DeprecationWarning from inspect.getargspec (#6817)" (#7018)
This reverts commit
ced84c4b
.
上级
ced84c4b
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
38 addition
and
116 deletion
+38
-116
keras/__init__.py
keras/__init__.py
+1
-1
keras/backend/tensorflow_backend.py
keras/backend/tensorflow_backend.py
+3
-3
keras/backend/theano_backend.py
keras/backend/theano_backend.py
+3
-3
keras/engine/topology.py
keras/engine/topology.py
+4
-4
keras/layers/core.py
keras/layers/core.py
+3
-2
keras/layers/wrappers.py
keras/layers/wrappers.py
+4
-3
keras/legacy/layers.py
keras/legacy/layers.py
+4
-2
keras/utils/__init__.py
keras/utils/__init__.py
+2
-2
keras/utils/generic_utils.py
keras/utils/generic_utils.py
+2
-43
keras/utils/test_utils.py
keras/utils/test_utils.py
+2
-4
keras/wrappers/scikit_learn.py
keras/wrappers/scikit_learn.py
+9
-6
tests/keras/utils/generic_utils_test.py
tests/keras/utils/generic_utils_test.py
+1
-43
未找到文件。
keras/__init__.py
浏览文件 @
4135aeeb
from
__future__
import
absolute_import
from
.
import
utils
from
.
import
activations
from
.
import
applications
from
.
import
backend
...
...
@@ -8,6 +7,7 @@ from . import datasets
from
.
import
engine
from
.
import
layers
from
.
import
preprocessing
from
.
import
utils
from
.
import
wrappers
from
.
import
callbacks
from
.
import
constraints
...
...
keras/backend/tensorflow_backend.py
浏览文件 @
4135aeeb
...
...
@@ -7,14 +7,13 @@ from tensorflow.python.ops import ctc_ops as ctc
from
tensorflow.python.ops
import
variables
as
tf_variables
from
collections
import
defaultdict
import
inspect
import
numpy
as
np
import
os
from
.common
import
floatx
from
.common
import
_EPSILON
from
.common
import
image_data_format
from
..utils.generic_utils
import
has_arg
# Legacy functions
from
.common
import
set_image_dim_ordering
...
...
@@ -2286,7 +2285,8 @@ def function(inputs, outputs, updates=None, **kwargs):
"""
if
kwargs
:
for
key
in
kwargs
:
if
not
(
has_arg
(
tf
.
Session
.
run
,
key
,
True
)
or
has_arg
(
Function
.
__init__
,
key
,
True
)):
if
(
key
not
in
inspect
.
getargspec
(
tf
.
Session
.
run
)[
0
]
and
key
not
in
inspect
.
getargspec
(
Function
.
__init__
)[
0
]):
msg
=
'Invalid argument "%s" passed to K.function with Tensorflow backend'
%
key
raise
ValueError
(
msg
)
return
Function
(
inputs
,
outputs
,
updates
=
updates
,
**
kwargs
)
...
...
keras/backend/theano_backend.py
浏览文件 @
4135aeeb
...
...
@@ -14,10 +14,9 @@ try:
from
theano.tensor.nnet.nnet
import
softsign
as
T_softsign
except
ImportError
:
from
theano.sandbox.softsign
import
softsign
as
T_softsign
import
inspect
import
numpy
as
np
from
.common
import
_FLOATX
,
floatx
,
_EPSILON
,
image_data_format
from
..utils.generic_utils
import
has_arg
# Legacy functions
from
.common
import
set_image_dim_ordering
,
image_dim_ordering
...
...
@@ -1195,8 +1194,9 @@ class Function(object):
def
function
(
inputs
,
outputs
,
updates
=
[],
**
kwargs
):
if
len
(
kwargs
)
>
0
:
function_args
=
inspect
.
getargspec
(
theano
.
function
)[
0
]
for
key
in
kwargs
.
keys
():
if
not
has_arg
(
theano
.
function
,
key
,
True
)
:
if
key
not
in
function_args
:
msg
=
'Invalid argument "%s" passed to K.function with Theano backend'
%
key
raise
ValueError
(
msg
)
return
Function
(
inputs
,
outputs
,
updates
=
updates
,
**
kwargs
)
...
...
keras/engine/topology.py
浏览文件 @
4135aeeb
...
...
@@ -10,13 +10,13 @@ import warnings
import
copy
import
os
import
re
import
inspect
from
six.moves
import
zip
from
..
import
backend
as
K
from
..
import
initializers
from
..utils.io_utils
import
ask_to_proceed_with_overwrite
from
..utils.layer_utils
import
print_summary
as
print_layer_summary
from
..utils.generic_utils
import
has_arg
from
..utils
import
conv_utils
from
..legacy
import
interfaces
...
...
@@ -584,7 +584,7 @@ class Layer(object):
user_kwargs
=
copy
.
copy
(
kwargs
)
if
not
_is_all_none
(
previous_mask
):
# The previous layer generated a mask.
if
has_arg
(
self
.
call
,
'mask'
)
:
if
'mask'
in
inspect
.
getargspec
(
self
.
call
).
args
:
if
'mask'
not
in
kwargs
:
# If mask is explicitly passed to __call__,
# we should override the default mask.
...
...
@@ -2206,7 +2206,7 @@ class Container(Layer):
kwargs
=
{}
if
len
(
computed_data
)
==
1
:
computed_tensor
,
computed_mask
=
computed_data
[
0
]
if
has_arg
(
layer
.
call
,
'mask'
)
:
if
'mask'
in
inspect
.
getargspec
(
layer
.
call
).
args
:
if
'mask'
not
in
kwargs
:
kwargs
[
'mask'
]
=
computed_mask
output_tensors
=
_to_list
(
layer
.
call
(
computed_tensor
,
**
kwargs
))
...
...
@@ -2217,7 +2217,7 @@ class Container(Layer):
else
:
computed_tensors
=
[
x
[
0
]
for
x
in
computed_data
]
computed_masks
=
[
x
[
1
]
for
x
in
computed_data
]
if
has_arg
(
layer
.
call
,
'mask'
)
:
if
'mask'
in
inspect
.
getargspec
(
layer
.
call
).
args
:
if
'mask'
not
in
kwargs
:
kwargs
[
'mask'
]
=
computed_masks
output_tensors
=
_to_list
(
layer
.
call
(
computed_tensors
,
**
kwargs
))
...
...
keras/layers/core.py
浏览文件 @
4135aeeb
...
...
@@ -5,6 +5,7 @@ from __future__ import division
import
numpy
as
np
import
copy
import
inspect
import
types
as
python_types
import
warnings
...
...
@@ -18,7 +19,6 @@ from ..engine import Layer
from
..utils.generic_utils
import
func_dump
from
..utils.generic_utils
import
func_load
from
..utils.generic_utils
import
deserialize_keras_object
from
..utils.generic_utils
import
has_arg
from
..legacy
import
interfaces
...
...
@@ -642,7 +642,8 @@ class Lambda(Layer):
def
call
(
self
,
inputs
,
mask
=
None
):
arguments
=
self
.
arguments
if
has_arg
(
self
.
function
,
'mask'
):
arg_spec
=
inspect
.
getargspec
(
self
.
function
)
if
'mask'
in
arg_spec
.
args
:
arguments
[
'mask'
]
=
mask
return
self
.
function
(
inputs
,
**
arguments
)
...
...
keras/layers/wrappers.py
浏览文件 @
4135aeeb
...
...
@@ -2,9 +2,9 @@
from
__future__
import
absolute_import
import
copy
import
inspect
from
..engine
import
Layer
from
..engine
import
InputSpec
from
..utils.generic_utils
import
has_arg
from
..
import
backend
as
K
...
...
@@ -272,9 +272,10 @@ class Bidirectional(Wrapper):
def
call
(
self
,
inputs
,
training
=
None
,
mask
=
None
):
kwargs
=
{}
if
has_arg
(
self
.
layer
.
call
,
'training'
):
func_args
=
inspect
.
getargspec
(
self
.
layer
.
call
).
args
if
'training'
in
func_args
:
kwargs
[
'training'
]
=
training
if
has_arg
(
self
.
layer
.
call
,
'mask'
)
:
if
'mask'
in
func_args
:
kwargs
[
'mask'
]
=
mask
y
=
self
.
forward_layer
.
call
(
inputs
,
**
kwargs
)
...
...
keras/legacy/layers.py
浏览文件 @
4135aeeb
import
inspect
import
types
as
python_types
import
warnings
from
..engine.topology
import
Layer
,
InputSpec
from
..
import
backend
as
K
from
..utils.generic_utils
import
func_dump
,
func_load
,
has_arg
from
..utils.generic_utils
import
func_dump
,
func_load
from
..
import
regularizers
from
..
import
constraints
from
..
import
activations
...
...
@@ -196,7 +197,8 @@ class Merge(Layer):
# Case: "mode" is a lambda or function.
if
callable
(
self
.
mode
):
arguments
=
self
.
arguments
if
has_arg
(
self
.
mode
,
'mask'
):
arg_spec
=
inspect
.
getargspec
(
self
.
mode
)
if
'mask'
in
arg_spec
.
args
:
arguments
[
'mask'
]
=
mask
return
self
.
mode
(
inputs
,
**
arguments
)
...
...
keras/utils/__init__.py
浏览文件 @
4135aeeb
from
__future__
import
absolute_import
from
.
import
np_utils
from
.
import
generic
_utils
from
.
import
conv
_utils
from
.
import
data_utils
from
.
import
generic_utils
from
.
import
io_utils
from
.
import
conv_utils
# Globally-importable utils.
from
.io_utils
import
HDF5Matrix
...
...
keras/utils/generic_utils.py
浏览文件 @
4135aeeb
...
...
@@ -132,8 +132,9 @@ def deserialize_keras_object(identifier, module_objects=None,
raise
ValueError
(
'Unknown '
+
printable_module_name
+
': '
+
class_name
)
if
hasattr
(
cls
,
'from_config'
):
arg_spec
=
inspect
.
getargspec
(
cls
.
from_config
)
custom_objects
=
custom_objects
or
{}
if
has_arg
(
cls
.
from_config
,
'custom_objects'
)
:
if
'custom_objects'
in
arg_spec
.
args
:
return
cls
.
from_config
(
config
[
'config'
],
custom_objects
=
dict
(
list
(
_GLOBAL_CUSTOM_OBJECTS
.
items
())
+
list
(
custom_objects
.
items
())))
...
...
@@ -206,48 +207,6 @@ def func_load(code, defaults=None, closure=None, globs=None):
closure
=
closure
)
def
has_arg
(
fn
,
name
,
accept_all
=
False
):
"""Checks if a callable accepts a given keyword argument.
For Python 2, checks if there is an argument with the given name.
For Python 3, checks if there is an argument with the given name, and
also whether this argument can be called with a keyword (i.e. if it is
not a positional-only argument).
# Arguments
fn: Callable to inspect.
name: Check if `fn` can be called with `name` as a keyword argument.
accept_all: What to return if there is no parameter called `name`
but the function accepts a `**kwargs` argument.
# Returns
bool, whether `fn` accepts a `name` keyword argument.
"""
if
sys
.
version_info
<
(
3
,):
arg_spec
=
inspect
.
getargspec
(
fn
)
if
accept_all
and
arg_spec
.
keywords
is
not
None
:
return
True
return
(
name
in
arg_spec
.
args
)
elif
sys
.
version_info
<
(
3
,
3
):
arg_spec
=
inspect
.
getfullargspec
(
fn
)
if
accept_all
and
arg_spec
.
varkw
is
not
None
:
return
True
return
(
name
in
arg_spec
.
args
or
name
in
arg_spec
.
kwonlyargs
)
else
:
signature
=
inspect
.
signature
(
fn
)
parameter
=
signature
.
parameters
.
get
(
name
)
if
parameter
is
None
:
if
accept_all
:
for
param
in
signature
.
parameters
.
values
():
if
param
.
kind
==
inspect
.
Parameter
.
VAR_KEYWORD
:
return
True
return
False
return
(
parameter
.
kind
in
(
inspect
.
Parameter
.
POSITIONAL_OR_KEYWORD
,
inspect
.
Parameter
.
KEYWORD_ONLY
))
class
Progbar
(
object
):
"""Displays a progress bar.
...
...
keras/utils/test_utils.py
浏览文件 @
4135aeeb
"""Utilities related to Keras unit tests."""
import
numpy
as
np
from
numpy.testing
import
assert_allclose
import
inspect
import
six
from
.generic_utils
import
has_arg
from
..engine
import
Model
,
Input
from
..models
import
Sequential
from
..models
import
model_from_json
...
...
@@ -71,9 +71,7 @@ def layer_test(layer_cls, kwargs={}, input_shape=None, input_dtype=None,
layer
.
set_weights
(
weights
)
# test and instantiation from weights
# Checking for empty weights array to avoid a problem where some
# legacy layers return bad values from get_weights()
if
has_arg
(
layer_cls
.
__init__
,
'weights'
)
and
len
(
weights
):
if
'weights'
in
inspect
.
getargspec
(
layer_cls
.
__init__
):
kwargs
[
'weights'
]
=
weights
layer
=
layer_cls
(
**
kwargs
)
...
...
keras/wrappers/scikit_learn.py
浏览文件 @
4135aeeb
from
__future__
import
absolute_import
import
copy
import
inspect
import
types
import
numpy
as
np
from
..utils.np_utils
import
to_categorical
from
..utils.generic_utils
import
has_arg
from
..models
import
Sequential
...
...
@@ -75,11 +75,13 @@ class BaseWrapper(object):
else
:
legal_params_fns
.
append
(
self
.
build_fn
)
legal_params
=
[]
for
fn
in
legal_params_fns
:
legal_params
+=
inspect
.
getargspec
(
fn
)[
0
]
legal_params
=
set
(
legal_params
)
for
params_name
in
params
:
for
fn
in
legal_params_fns
:
if
has_arg
(
fn
,
params_name
):
break
else
:
if
params_name
not
in
legal_params
:
if
params_name
!=
'nb_epoch'
:
raise
ValueError
(
'{} is not a legal parameter'
.
format
(
params_name
))
...
...
@@ -161,8 +163,9 @@ class BaseWrapper(object):
"""
override
=
override
or
{}
res
=
{}
fn_args
=
inspect
.
getargspec
(
fn
)[
0
]
for
name
,
value
in
self
.
sk_params
.
items
():
if
has_arg
(
fn
,
name
)
:
if
name
in
fn_args
:
res
.
update
({
name
:
value
})
res
.
update
(
override
)
return
res
...
...
tests/keras/utils/generic_utils_test.py
浏览文件 @
4135aeeb
import
sys
import
pytest
from
keras.utils.generic_utils
import
custom_object_scope
,
has_arg
from
keras.utils.generic_utils
import
custom_object_scope
from
keras
import
activations
from
keras
import
regularizers
...
...
@@ -21,46 +20,5 @@ def test_custom_objects_scope():
assert
cl
.
__class__
==
CustomClass
@
pytest
.
mark
.
parametrize
(
'fn, name, accept_all, expected'
,
[
(
'f(x)'
,
'x'
,
False
,
True
),
(
'f(x)'
,
'y'
,
False
,
False
),
(
'f(x)'
,
'y'
,
True
,
False
),
(
'f(x, y)'
,
'y'
,
False
,
True
),
(
'f(x, y=1)'
,
'y'
,
False
,
True
),
(
'f(x, **kwargs)'
,
'x'
,
False
,
True
),
(
'f(x, **kwargs)'
,
'y'
,
False
,
False
),
(
'f(x, **kwargs)'
,
'y'
,
True
,
True
),
(
'f(x, y=1, **kwargs)'
,
'y'
,
False
,
True
),
# Keyword-only arguments (Python 3 only)
(
'f(x, *args, y=1)'
,
'y'
,
False
,
True
),
(
'f(x, *args, y=1)'
,
'z'
,
True
,
False
),
(
'f(x, *, y=1)'
,
'x'
,
False
,
True
),
(
'f(x, *, y=1)'
,
'y'
,
False
,
True
),
# lambda
(
lambda
x
:
x
,
'x'
,
False
,
True
),
(
lambda
x
:
x
,
'y'
,
False
,
False
),
(
lambda
x
:
x
,
'y'
,
True
,
False
),
])
def
test_has_arg
(
fn
,
name
,
accept_all
,
expected
):
if
isinstance
(
fn
,
str
):
context
=
dict
()
try
:
exec
(
'def {}: pass'
.
format
(
fn
),
context
)
except
SyntaxError
:
if
sys
.
version_info
>=
(
3
,):
raise
pytest
.
skip
(
'Function is not compatible with Python 2'
)
context
.
pop
(
'__builtins__'
,
None
)
# Sometimes exec adds builtins to the context
fn
,
=
context
.
values
()
assert
has_arg
(
fn
,
name
,
accept_all
)
is
expected
@
pytest
.
mark
.
xfail
(
sys
.
version_info
<
(
3
,
3
),
reason
=
'inspect API does not reveal positional-only arguments'
)
def
test_has_arg_positional_only
():
assert
has_arg
(
pow
,
'x'
)
is
False
if
__name__
==
'__main__'
:
pytest
.
main
([
__file__
])
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录