Commit b51d81f8 authored by Scott Zhu, committed by TensorFlower Gardener

Update the default recurrent activation function for the unified LSTM to 'sigmoid'.

I believe that, for historical reasons, the recurrent activation function for LSTM has defaulted to hard_sigmoid because it is faster to compute than sigmoid. With the new unified LSTM, the performance concern should be addressed by grappler swapping in the faster backend.

PiperOrigin-RevId: 224863406
Parent 1d54cbf4
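To illustrate what the new default means for callers (a minimal sketch only; the import path and layer size below are assumptions for illustration, not part of this commit):

from tensorflow.python.keras.layers.recurrent import UnifiedLSTM  # path assumed

# With this change, recurrent_activation defaults to 'sigmoid', which is what
# the CuDNN kernel computes, so a plain construction stays CuDNN-compatible.
layer = UnifiedLSTM(32)

# The previous default can still be requested explicitly, but it forces the
# layer onto the generic (non-CuDNN) implementation.
legacy_layer = UnifiedLSTM(32, recurrent_activation='hard_sigmoid')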
@@ -2546,13 +2546,11 @@ class UnifiedLSTM(LSTM):
   Arguments:
     units: Positive integer, dimensionality of the output space.
     activation: Activation function to use.
-      Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation
-      is applied
-      (ie. "linear" activation: `a(x) = x`).
+      Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation
+      is applied (ie. "linear" activation: `a(x) = x`).
     recurrent_activation: Activation function to use for the recurrent step.
-      Default: hard sigmoid (`hard_sigmoid`). If you pass `None`, no
-      activation is applied
-      (ie. "linear" activation: `a(x) = x`).
+      Default: sigmoid (`sigmoid`). If you pass `None`, no activation is
+      applied (ie. "linear" activation: `a(x) = x`).
     use_bias: Boolean, whether the layer uses a bias vector.
     kernel_initializer: Initializer for the `kernel` weights matrix, used for
       the linear transformation of the inputs..
@@ -2602,7 +2600,7 @@ class UnifiedLSTM(LSTM):
   def __init__(self,
                units,
                activation='tanh',
-               recurrent_activation='hard_sigmoid',
+               recurrent_activation='sigmoid',
                use_bias=True,
                kernel_initializer='glorot_uniform',
                recurrent_initializer='orthogonal',
@@ -2663,8 +2661,9 @@ class UnifiedLSTM(LSTM):
     self._num_inputs = None
     self._dropout_mask = None
     self.could_use_cudnn = (
-        activation == 'tanh' and recurrent_dropout == 0 and
-        not unroll and use_bias and bias_regularizer is None)
+        activation == 'tanh' and recurrent_activation == 'sigmoid' and
+        recurrent_dropout == 0 and not unroll and use_bias and
+        bias_regularizer is None)
 
   def call(self, inputs, mask=None, training=None, initial_state=None):
     # LSTM does not support constants. Ignore it during process.
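As a standalone restatement of the compatibility check above (a hypothetical helper for illustration only; in the layer itself the result is stored as the could_use_cudnn attribute set in UnifiedLSTM.__init__):

# Hypothetical helper mirroring the condition in the diff above.
def could_use_cudnn(activation='tanh', recurrent_activation='sigmoid',
                    recurrent_dropout=0.0, unroll=False, use_bias=True,
                    bias_regularizer=None):
  return (activation == 'tanh' and recurrent_activation == 'sigmoid' and
          recurrent_dropout == 0 and not unroll and use_bias and
          bias_regularizer is None)

print(could_use_cudnn())                                     # True: the new defaults qualify
print(could_use_cudnn(recurrent_activation='hard_sigmoid'))  # False: the old default no longer does
print(could_use_cudnn(recurrent_dropout=0.1))                # False: recurrent dropout rules out CuDNN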
@@ -161,17 +161,20 @@ class UnifiedLSTMTest(test.TestCase, parameterized.TestCase):
       existing_loss = loss_value
 
   @parameterized.named_parameters(
-      ('_non_tan_activation', 'relu', 0, False, True, None),
-      ('_use_recurrent_dropout', 'tanh', 0.1, False, True, None),
-      ('_unroll', 'tanh', 0, True, True, None),
-      ('_not_use_bias', 'tanh', 0, False, False, None),
-      ('_use_bias_regularizer', 'tanh', 0, False, True, 'l2')
+      ('non_tan_activation', 'relu', 'sigmoid', 0, False, True, None),
+      ('non_sigmoid_recur_activation', 'tanh', 'relu', 0, False, True, None),
+      ('use_recurrent_dropout', 'tanh', 'sigmoid', 0.1, False, True, None),
+      ('unroll', 'tanh', 'sigmoid', 0, True, True, None),
+      ('not_use_bias', 'tanh', 'sigmoid', 0, False, False, None),
+      ('use_bias_regularizer', 'tanh', 'sigmoid', 0, False, True, 'l2')
   )
   @test_util.run_in_graph_and_eager_modes(config=_config)
-  def test_could_use_defun_backend(self, activation, recurrent_dropout,
-                                   unroll, use_bias, bias_regularizer):
+  def test_could_use_defun_backend(self, activation, recurrent_activation,
+                                   recurrent_dropout, unroll, use_bias,
+                                   bias_regularizer):
     layer = UnifiedLSTM(1,
                         activation=activation,
+                        recurrent_activation=recurrent_activation,
                         recurrent_dropout=recurrent_dropout,
                         unroll=unroll,
                         use_bias=use_bias,
@@ -270,22 +273,22 @@ class UnifiedLSTMTest(test.TestCase, parameterized.TestCase):
     inputs = keras.layers.Input(
         shape=[timestep, input_shape], dtype=dtypes.float32)
     with test_util.device(use_gpu=False):
-      # Note that CuDNN use 'sigmoid' as activation. Force the CPU
-      # implementation to use 'sigmoid' so that it will generate same output as
-      # CuDNN implementation.
-      layer = UnifiedLSTM(rnn_state_size, recurrent_activation='sigmoid')
+      layer = UnifiedLSTM(rnn_state_size)
       output = layer(inputs)
       cpu_model = keras.models.Model(inputs, output)
       weights = cpu_model.get_weights()
       y_1 = cpu_model.predict(x_train)
     with test_util.device(use_gpu=True):
-      layer = UnifiedLSTM(rnn_state_size, recurrent_activation='sigmoid')
+      layer = UnifiedLSTM(rnn_state_size)
       output = layer(inputs)
       gpu_model = keras.models.Model(inputs, output)
       gpu_model.set_weights(weights)
       y_2 = gpu_model.predict(x_train)
+    # Note that CuDNN uses 'sigmoid' as activation, so the unified LSTM uses
+    # 'sigmoid' as default. Construct the canonical LSTM with sigmoid to achieve
+    # the same output.
     with test_util.device(use_gpu=True):
       layer = keras.layers.LSTM(rnn_state_size, recurrent_activation='sigmoid')
       output = layer(inputs)
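The test above builds the same layer on CPU and GPU with shared weights; a hedged sketch of the comparison it is driving at (numpy-based, with y_1/y_2 taken from the diff; the actual assertion in the test file is not shown here):

import numpy as np

def outputs_match(y_1, y_2, rtol=1e-5, atol=1e-5):
  # y_1 is the CPU model's prediction and y_2 the GPU model's prediction made
  # with the same weights; with the shared 'sigmoid' default they should agree.
  return np.allclose(y_1, y_2, rtol=rtol, atol=atol)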