Commit 5741f4b9 authored by Scott Zhu, committed by TensorFlower Gardener

Fix GRU cell breakage when reset_after=True in eager mode.

Also added a unit test to cover it.

PiperOrigin-RevId: 225028823
Parent: 4143d8d3
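For context, a minimal repro of the failure mode this commit addresses might look like the following. This is a sketch, not taken from the commit; it assumes the TF 1.x-era API this codebase targets:

import numpy as np
import tensorflow as tf

tf.enable_eager_execution()  # the breakage only appeared in eager mode

# A GRU layer with reset_after=True, the configuration that previously broke.
model = tf.keras.Sequential(
    [tf.keras.layers.GRU(2, reset_after=True, input_shape=(3, 4))])
model.compile('rmsprop', 'mse')

x = np.random.random((2, 3, 4))  # (num_samples, timesteps, embedding_dim)
y = np.random.random((2, 2))     # (num_samples, units)
model.fit(x, y)                  # broke here before this fix
model.predict(x)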
@@ -81,6 +81,29 @@ class GRULayerTest(test.TestCase):
                 'implementation': mode},
             input_shape=(num_samples, timesteps, embedding_dim))

+  @tf_test_util.run_in_graph_and_eager_modes
+  def test_reset_after_GRU(self):
+    num_samples = 2
+    timesteps = 3
+    embedding_dim = 4
+    units = 2
+
+    (x_train, y_train), _ = testing_utils.get_test_data(
+        train_samples=num_samples,
+        test_samples=0,
+        input_shape=(timesteps, embedding_dim),
+        num_classes=units)
+    y_train = keras.utils.to_categorical(y_train, units)
+
+    inputs = keras.layers.Input(shape=[timesteps, embedding_dim])
+    gru_layer = keras.layers.GRU(units,
+                                 reset_after=True)
+    output = gru_layer(inputs)
+    gru_model = keras.models.Model(inputs, output)
+    gru_model.compile('rmsprop', 'mse')
+    gru_model.fit(x_train, y_train)
+    gru_model.predict(x_train)
+
   def test_statefulness_GRU(self):
     num_samples = 2
     timesteps = 3
@@ -1497,12 +1497,6 @@ class GRUCell(Layer):
                                   initializer=self.bias_initializer,
                                   regularizer=self.bias_regularizer,
                                   constraint=self.bias_constraint)
-      if not self.reset_after:
-        self.input_bias, self.recurrent_bias = self.bias, None
-      else:
-        self.input_bias = K.flatten(self.bias[0])
-        self.recurrent_bias = K.flatten(self.bias[1])
     else:
       self.bias = None
     self.built = True
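Why the removed block was a problem: in eager mode, build() executes immediately, so K.flatten(self.bias[0]) materializes a constant snapshot of the freshly initialized variable. The cached input_bias and recurrent_bias then never see subsequent training updates to self.bias, and gradients do not reach the variable through them. A minimal sketch of that snapshot behavior (illustrative names, assuming TF 1.x eager execution):

import tensorflow as tf

tf.enable_eager_execution()

v = tf.Variable([[1., 2.], [3., 4.]])
snapshot = tf.reshape(v[0], [-1])      # evaluates eagerly to a constant tensor now

v.assign_add(tf.ones_like(v))          # e.g. an optimizer update
print(snapshot.numpy())                # still [1. 2.] -- does not track the variable
print(tf.reshape(v[0], [-1]).numpy())  # [2. 3.] when recomputed per call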
@@ -1529,6 +1523,12 @@ class GRUCell(Layer):
     # dropout matrices for recurrent units
     rec_dp_mask = self._recurrent_dropout_mask

+    if self.use_bias:
+      if not self.reset_after:
+        input_bias, recurrent_bias = self.bias, None
+      else:
+        input_bias, recurrent_bias = array_ops.unstack(self.bias)
+
     if self.implementation == 1:
       if 0. < self.dropout < 1.:
         inputs_z = inputs * dp_mask[0]
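The replacement splits the bias on every call() instead, so the split always reads the live variable in both graph and eager modes. When reset_after=True the bias variable has shape (2, 3 * units), and array_ops.unstack separates it along axis 0. A small sketch, assuming the module-level import of array_ops that recurrent.py uses:

import tensorflow as tf
from tensorflow.python.ops import array_ops

units = 2
bias = tf.Variable(tf.zeros((2, 3 * units)))          # row 0: input bias, row 1: recurrent bias
input_bias, recurrent_bias = array_ops.unstack(bias)  # two tensors of shape (3 * units,)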
@@ -1544,9 +1544,9 @@ class GRUCell(Layer):
       x_h = K.dot(inputs_h, self.kernel[:, self.units * 2:])
       if self.use_bias:
-        x_z = K.bias_add(x_z, self.input_bias[:self.units])
-        x_r = K.bias_add(x_r, self.input_bias[self.units: self.units * 2])
-        x_h = K.bias_add(x_h, self.input_bias[self.units * 2:])
+        x_z = K.bias_add(x_z, input_bias[:self.units])
+        x_r = K.bias_add(x_r, input_bias[self.units: self.units * 2])
+        x_h = K.bias_add(x_h, input_bias[self.units * 2:])

       if 0. < self.recurrent_dropout < 1.:
         h_tm1_z = h_tm1 * rec_dp_mask[0]
@@ -1561,10 +1561,9 @@ class GRUCell(Layer):
       recurrent_r = K.dot(h_tm1_r,
                           self.recurrent_kernel[:, self.units:self.units * 2])
       if self.reset_after and self.use_bias:
-        recurrent_z = K.bias_add(recurrent_z, self.recurrent_bias[:self.units])
+        recurrent_z = K.bias_add(recurrent_z, recurrent_bias[:self.units])
         recurrent_r = K.bias_add(recurrent_r,
-                                 self.recurrent_bias[self.units:
-                                                     self.units * 2])
+                                 recurrent_bias[self.units:self.units * 2])

       z = self.recurrent_activation(x_z + recurrent_z)
       r = self.recurrent_activation(x_r + recurrent_r)
@@ -1573,8 +1572,7 @@ class GRUCell(Layer):
       if self.reset_after:
         recurrent_h = K.dot(h_tm1_h, self.recurrent_kernel[:, self.units * 2:])
         if self.use_bias:
-          recurrent_h = K.bias_add(recurrent_h,
-                                   self.recurrent_bias[self.units * 2:])
+          recurrent_h = K.bias_add(recurrent_h, recurrent_bias[self.units * 2:])
         recurrent_h = r * recurrent_h
       else:
         recurrent_h = K.dot(r * h_tm1_h,
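For reference, the branch above is where the two GRU variants diverge: with reset_after=True the candidate state is hh = tanh(x_h + r * (h_tm1 . U_h + b_hr)), i.e. the reset gate is applied after the recurrent matmul (the CuDNN-compatible formulation), whereas the default applies it before: hh = tanh(x_h + (r * h_tm1) . U_h).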
@@ -1589,7 +1587,7 @@ class GRUCell(Layer):
       matrix_x = K.dot(inputs, self.kernel)
       if self.use_bias:
         # biases: bias_z_i, bias_r_i, bias_h_i
-        matrix_x = K.bias_add(matrix_x, self.input_bias)
+        matrix_x = K.bias_add(matrix_x, input_bias)

       x_z = matrix_x[:, :self.units]
       x_r = matrix_x[:, self.units: 2 * self.units]
@@ -1602,7 +1600,7 @@ class GRUCell(Layer):
         # hidden state projected by all gate matrices at once
         matrix_inner = K.dot(h_tm1, self.recurrent_kernel)
         if self.use_bias:
-          matrix_inner = K.bias_add(matrix_inner, self.recurrent_bias)
+          matrix_inner = K.bias_add(matrix_inner, recurrent_bias)
       else:
         # hidden state projected separately for update/reset and new
         matrix_inner = K.dot(h_tm1, self.recurrent_kernel[:, :2 * self.units])
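Putting the fused implementation == 2 path together, here is a NumPy sketch of a single GRU step with reset_after=True, consistent with the code above (hypothetical standalone function; dropout omitted; shapes only for illustration):

import numpy as np

def gru_step_reset_after(inputs, h_tm1, kernel, recurrent_kernel, bias, units):
  """One fused step: kernel (in_dim, 3u), recurrent_kernel (u, 3u), bias (2, 3u)."""
  sigmoid = lambda a: 1. / (1. + np.exp(-a))
  input_bias, recurrent_bias = bias[0], bias[1]  # split per step, as in the fix
  matrix_x = inputs @ kernel + input_bias                    # x projected by all gates
  matrix_inner = h_tm1 @ recurrent_kernel + recurrent_bias   # h projected by all gates
  x_z, x_r, x_h = np.split(matrix_x, 3, axis=-1)
  rec_z, rec_r, rec_h = np.split(matrix_inner, 3, axis=-1)
  z = sigmoid(x_z + rec_z)       # update gate
  r = sigmoid(x_r + rec_r)       # reset gate
  hh = np.tanh(x_h + r * rec_h)  # reset applied after the recurrent matmul
  return z * h_tm1 + (1. - z) * hh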