提交 f2bbf98b 编写于 作者: F François Chollet

Fix issue where the disable_tracking decorator obfuscates layer constructors.

上级 9ad5a18f
......@@ -73,8 +73,6 @@ nav:
- Baby RNN: examples/babi_rnn.md
- Baby MemNN: examples/babi_memnn.md
- CIFAR-10 CNN: examples/cifar10_cnn.md
- CIFAR-10 CNN-Capsule: examples/cifar10_cnn_capsule.md
- CIFAR-10 CNN with augmentation (TF): examples/cifar10_cnn_tfaugment2d.md
- CIFAR-10 ResNet: examples/cifar10_resnet.md
- Convolution filter visualization: examples/conv_filter_visualization.md
- Convolutional LSTM: examples/conv_lstm.md
......
......@@ -5,7 +5,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ..engine.base_layer import Layer, InputSpec, disable_tracking
from ..engine.base_layer import Layer, InputSpec
from .. import initializers
from .. import regularizers
from .. import constraints
......
......@@ -46,7 +46,6 @@ class StackedRNNCells(Layer):
```
"""
@disable_tracking
def __init__(self, cells, **kwargs):
for cell in cells:
if not hasattr(cell, 'call'):
......@@ -391,7 +390,6 @@ class RNN(Layer):
```
"""
@disable_tracking
def __init__(self, cell,
return_sequences=False,
return_state=False,
......@@ -410,7 +408,7 @@ class RNN(Layer):
'(tuple of integers, '
'one integer per RNN state).')
super(RNN, self).__init__(**kwargs)
self.cell = cell
self._set_cell(cell)
self.return_sequences = return_sequences
self.return_state = return_state
self.go_backwards = go_backwards
......@@ -424,6 +422,13 @@ class RNN(Layer):
self.constants_spec = None
self._num_constants = None
@disable_tracking
def _set_cell(self, cell):
    """Attach the wrapped RNN cell without it being tracked as a sub-layer.

    Kept as a separate method so the `disable_tracking` decorator can be
    applied here instead of on `__init__`, leaving the constructor's
    visible signature unobscured.
    """
    self.cell = cell
@property
def states(self):
if self._states is None:
......
......@@ -360,18 +360,12 @@ class Bidirectional(Wrapper):
```
"""
@disable_tracking
def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
if merge_mode not in ['sum', 'mul', 'ave', 'concat', None]:
raise ValueError('Invalid merge mode. '
'Merge mode should be one of '
'{"sum", "mul", "ave", "concat", None}')
self.forward_layer = copy.copy(layer)
config = layer.get_config()
config['go_backwards'] = not config['go_backwards']
self.backward_layer = layer.__class__.from_config(config)
self.forward_layer.name = 'forward_' + self.forward_layer.name
self.backward_layer.name = 'backward_' + self.backward_layer.name
self._set_sublayers(layer)
self.merge_mode = merge_mode
if weights:
nw = len(weights)
......@@ -386,6 +380,18 @@ class Bidirectional(Wrapper):
self.input_spec = layer.input_spec
self._num_constants = None
@disable_tracking
def _set_sublayers(self, layer):
    """Build the forward and backward sub-layers from `layer`.

    Kept as a separate method so the `disable_tracking` decorator can be
    applied here instead of on `__init__`, leaving the constructor's
    visible signature unobscured.
    """
    # Forward direction: a shallow copy of the wrapped layer.
    forward = copy.copy(layer)
    # Backward direction: rebuild from config with `go_backwards` flipped.
    backward_config = layer.get_config()
    backward_config['go_backwards'] = not backward_config['go_backwards']
    backward = layer.__class__.from_config(backward_config)
    forward.name = 'forward_' + forward.name
    backward.name = 'backward_' + backward.name
    self.forward_layer = forward
    self.backward_layer = backward
@property
def trainable(self):
    # Whether the wrapper's weights are trainable; backed by the private
    # `_trainable` attribute (set elsewhere in the class — not in view here).
    return self._trainable
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册