Commit a3664246 authored by Zhengping Che, committed by François Chollet

Handle `mask` in `TimeDistributed` wrapper. (#10242)

* equip TimeDistributed with mask support and unspecified input lengths

* fix bugs in Theano; add a test for TimeDistributed + Masking

* skip tests on CNTK with multiple unspecified time lengths

* move static shape inference to theano_backend, add docstring, etc.

* fix format
Parent 7365a99f
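For orientation, here is a minimal usage sketch of what this change enables (illustrative only, mirroring the new test added below; not code from this commit). It assumes the Keras 2.x Sequential API:

import numpy as np
from keras.models import Sequential
from keras.layers import TimeDistributed, Embedding, SimpleRNN

model = Sequential()
# Embedding(mask_zero=True) inside TimeDistributed now emits a mask, and
# both time dimensions may be left unspecified (None).
model.add(TimeDistributed(Embedding(5, 6, mask_zero=True),
                          input_shape=(None, None)))
model.add(TimeDistributed(SimpleRNN(7, return_sequences=False)))
model.add(SimpleRNN(1))
model.compile(optimizer='rmsprop', loss='mse')

x = np.random.randint(1, 5, size=(8, 3, 4))
x[:, -1, -1] = 0  # zero entries are masked through the whole stack
model.fit(x, np.random.random((8, 1)), epochs=1, verbose=0)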
@@ -585,7 +585,26 @@ def var(x, axis=None, keepdims=False):
 def any(x, axis=None, keepdims=False):
     """Bitwise reduction (logical OR).
     """
-    return T.any(x, axis=axis, keepdims=keepdims)
+    y = T.any(x, axis=axis, keepdims=keepdims)
+    if hasattr(x, '_keras_shape'):
+        if axis is None:
+            y._keras_shape = (1,) * len(x._keras_shape) if keepdims else (1,)
+        else:
+            if isinstance(axis, int):
+                axis_list = [axis]
+            else:
+                axis_list = list(set(int(a) for a in axis))
+            keras_shape_list = list(x._keras_shape)
+            if keepdims:
+                for a in axis_list:
+                    keras_shape_list[a] = 1
+            else:
+                for a in axis_list[::-1]:
+                    keras_shape_list.pop(a)
+            if not keras_shape_list:
+                keras_shape_list = (1,)
+            y._keras_shape = tuple(keras_shape_list)
+    return y


 def all(x, axis=None, keepdims=False):
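For intuition (not part of the diff): the `_keras_shape` bookkeeping added above mirrors NumPy's reduction semantics, which this standalone check illustrates:

import numpy as np

x = np.zeros((2, 3, 4), dtype=bool)
# keepdims=True keeps each reduced axis with size 1
assert np.any(x, axis=1, keepdims=True).shape == (2, 1, 4)
# without keepdims, the reduced axes are popped
assert np.any(x, axis=(0, 2)).shape == (3,)
# a full reduction yields a scalar; the Theano code records (1,)
# instead of (), since the Keras static shape is never empty
assert np.any(x, axis=None).shape == ()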
@@ -671,7 +690,12 @@ def equal(x, y):
 def not_equal(x, y):
-    return T.neq(x, y)
+    z = T.neq(x, y)
+    if hasattr(x, '_keras_shape'):
+        z._keras_shape = x._keras_shape
+    elif hasattr(y, '_keras_shape'):
+        z._keras_shape = y._keras_shape
+    return z


 def greater(x, y):
@@ -868,13 +892,12 @@ def concatenate(tensors, axis=-1):
 def reshape(x, shape):
     y = T.reshape(x, shape)
-    if _is_explicit_shape(shape):
-        shape = tuple(x if x != -1 else None for x in shape)
-        y._keras_shape = shape
-        if hasattr(x, '_uses_learning_phase'):
-            y._uses_learning_phase = x._uses_learning_phase
-        else:
-            y._uses_learning_phase = False
+    shape = tuple(x if isinstance(x, int) and x > 0 else None for x in shape)
+    y._keras_shape = shape
+    if hasattr(x, '_uses_learning_phase'):
+        y._uses_learning_phase = x._uses_learning_phase
+    else:
+        y._uses_learning_phase = False
     return y
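For reference (illustrative, not from the diff), the new guard treats anything that is not a positive int as an unknown dimension, not just -1:

# -1, 0, or a symbolic dimension object all map to None now
shape = (-1, 3, 0)
static = tuple(s if isinstance(s, int) and s > 0 else None for s in shape)
assert static == (None, 3, None)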
@@ -60,7 +60,8 @@ class Masking(Layer):
         self.mask_value = mask_value

     def compute_mask(self, inputs, mask=None):
-        return K.any(K.not_equal(inputs, self.mask_value), axis=-1)
+        output_mask = K.any(K.not_equal(inputs, self.mask_value), axis=-1)
+        return output_mask

     def call(self, inputs):
         boolean_mask = K.any(K.not_equal(inputs, self.mask_value),
@@ -93,6 +93,7 @@ class Embedding(Layer):
         self.activity_regularizer = regularizers.get(activity_regularizer)
         self.embeddings_constraint = constraints.get(embeddings_constraint)
         self.mask_zero = mask_zero
+        self.supports_masking = mask_zero
         self.input_length = input_length

     def build(self, input_shape):
@@ -108,8 +109,8 @@
     def compute_mask(self, inputs, mask=None):
         if not self.mask_zero:
             return None
-        else:
-            return K.not_equal(inputs, 0)
+        output_mask = K.not_equal(inputs, 0)
+        return output_mask

     def compute_output_shape(self, input_shape):
         if self.input_length is None:
@@ -160,6 +160,37 @@ class TimeDistributed(Wrapper):
         super(TimeDistributed, self).__init__(layer, **kwargs)
         self.supports_masking = True

+    def _get_shape_tuple(self, init_tuple, tensor, start_idx, int_shape=None):
+        """Finds non-specific dimensions in the static shapes and
+        replaces them with the corresponding dynamic shapes of the tensor.
+
+        # Arguments
+            init_tuple: a tuple, the first part of the output shape
+            tensor: the tensor from which to get the (static and dynamic)
+                shapes as the last part of the output shape
+            start_idx: int, which indicates the first dimension to take
+                from the static shape of the tensor
+            int_shape: an alternative static shape to take as the last
+                part of the output shape
+
+        # Returns
+            The new shape with the first part from `init_tuple` and the
+            last part from either `int_shape` (if provided) or
+            `K.int_shape(tensor)`, where every `None` is replaced by the
+            corresponding dimension from `K.shape(tensor)`.
+        """
+        # replace all None in int_shape by K.shape
+        if int_shape is None:
+            int_shape = K.int_shape(tensor)[start_idx:]
+        if not any(not s for s in int_shape):
+            return init_tuple + int_shape
+        tensor_shape = K.shape(tensor)
+        int_shape = list(int_shape)
+        for i, s in enumerate(int_shape):
+            if not s:
+                int_shape[i] = tensor_shape[start_idx + i]
+        return init_tuple + tuple(int_shape)
+
     def build(self, input_shape):
         assert len(input_shape) >= 3
         self.input_spec = InputSpec(shape=input_shape)
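To make the helper's contract concrete, here is a small standalone analogue (illustrative only; `merge_shapes` and the string dimensions are hypothetical stand-ins for symbolic backend shapes):

def merge_shapes(init_tuple, static_shape, dynamic_shape, start_idx):
    # dynamic_shape plays the role of K.shape(tensor); its entries
    # replace any None found in the static shape
    static = list(static_shape[start_idx:])
    for i, s in enumerate(static):
        if not s:  # None means "unknown at graph-build time"
            static[i] = dynamic_shape[start_idx + i]
    return init_tuple + tuple(static)

# batch and both time axes unknown, feature axis known:
print(merge_shapes((-1,), (None, None, None, 6), ('B', 'T1', 'T2', 6), 2))
# -> (-1, 'T2', 6): the unknown inner time axis comes from the dynamic
#    shape, while the known feature axis stays static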
@@ -204,18 +235,24 @@
             input_length = input_shape[1]
             if not input_length:
                 input_length = K.shape(inputs)[1]
+            inner_input_shape = self._get_shape_tuple((-1,), inputs, 2)
             # Shape: (num_samples * timesteps, ...). And track the
             # transformation in self._input_map.
             input_uid = object_list_uid(inputs)
-            inputs = K.reshape(inputs, (-1,) + input_shape[2:])
+            inputs = K.reshape(inputs, inner_input_shape)
             self._input_map[input_uid] = inputs
             # (num_samples * timesteps, ...)
+            if has_arg(self.layer.call, 'mask') and mask is not None:
+                inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
+                kwargs['mask'] = K.reshape(mask, inner_mask_shape)
             y = self.layer.call(inputs, **kwargs)
             if hasattr(y, '_uses_learning_phase'):
                 uses_learning_phase = y._uses_learning_phase
             # Shape: (num_samples, timesteps, ...)
             output_shape = self.compute_output_shape(input_shape)
-            y = K.reshape(y, (-1, input_length) + output_shape[2:])
+            output_shape = self._get_shape_tuple(
+                (-1, input_length), y, 1, output_shape[2:])
+            y = K.reshape(y, output_shape)

         # Apply activity regularizer if any:
         if (hasattr(self.layer, 'activity_regularizer') and
@@ -227,6 +264,70 @@
             y._uses_learning_phase = True
         return y

+    def compute_mask(self, inputs, mask=None):
+        """Computes an output mask tensor for the TimeDistributed layer
+        based on the inputs, the input mask, and the inner layer.
+
+        If the batch size is specified:
+            Simply return the input `mask`. (An RNN-based implementation
+            with more than one RNN input would be required, which is not
+            yet supported in Keras.)
+        Otherwise, we call `compute_mask` of the inner layer at each
+        time step.
+        If the output mask at each time step is not `None`
+        (e.g. the inner layer is a Masking or RNN layer):
+            Concatenate all of them and return the concatenation.
+        If the output mask at each time step is `None` and the input
+        mask is not `None` (e.g. the inner layer is a Dense layer):
+            Reduce the input mask to 2 dimensions and return it.
+        Otherwise (both the output mask and the input mask are `None`,
+        i.e. the mask is not used at all):
+            Return `None`.
+
+        # Arguments
+            inputs: Tensor
+            mask: Tensor
+
+        # Returns
+            None or a tensor
+        """
+        # Cases in which we need to call layer.compute_mask even when
+        # input_mask is None: a Masking layer, or an Embedding layer
+        # with mask_zero=True.
+        input_shape = K.int_shape(inputs)
+        if input_shape[0]:
+            # batch size matters; we currently do not handle the mask
+            # explicitly in this case
+            return mask
+        inner_mask = mask
+        if inner_mask is not None:
+            inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
+            inner_mask = K.reshape(inner_mask, inner_mask_shape)
+        input_uid = object_list_uid(inputs)
+        inner_inputs = self._input_map[input_uid]
+        output_mask = self.layer.compute_mask(inner_inputs, inner_mask)
+        if output_mask is None:
+            if mask is None:
+                return None
+            # input_mask is not None and output_mask is None:
+            # we should return a non-None mask
+            output_mask = mask
+            for _ in range(2, len(K.int_shape(mask))):
+                output_mask = K.any(output_mask, axis=-1)
+        else:
+            # output_mask is not None; we need to reshape it back to
+            # (num_samples, timesteps, ...)
+            input_length = input_shape[1]
+            if not input_length:
+                input_length = K.shape(inputs)[1]
+            output_mask_int_shape = K.int_shape(output_mask)
+            if output_mask_int_shape is None:
+                # if the output_mask does not have a static shape,
+                # its shape must be the same as the input mask's shape
+                if mask is not None:
+                    output_mask_int_shape = K.int_shape(mask)
+                else:
+                    output_mask_int_shape = self.compute_output_shape(input_shape)[:-1]
+            output_mask_shape = self._get_shape_tuple(
+                (-1, input_length), output_mask, 1, output_mask_int_shape[1:])
+            output_mask = K.reshape(output_mask, output_mask_shape)
+        return output_mask


 class Bidirectional(Wrapper):
     """Bidirectional wrapper for RNNs.
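As a plain-Python condensation (illustrative, not from the commit) of the three cases the new `compute_mask` distinguishes, with NumPy arrays standing in for symbolic mask tensors:

import numpy as np

def timedistributed_mask(input_mask, inner_output_mask):
    if inner_output_mask is not None:
        # Case 1: the inner layer produced a mask (e.g. Masking, RNN);
        # it is reshaped back to (batch, timesteps, ...) and returned.
        return inner_output_mask
    if input_mask is not None:
        # Case 2: the inner layer returned None (e.g. Dense) but a mask
        # came in; reduce it to rank 2 (batch, timesteps) with `any`.
        out = input_mask
        while out.ndim > 2:
            out = out.any(axis=-1)
        return out
    # Case 3: no mask in, no mask out.
    return None

m = np.array([[[True, True], [False, False]]])  # (1, 2, 2) input mask
print(timedistributed_mask(m, None))  # -> [[ True False]]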
@@ -156,6 +156,59 @@ def test_TimeDistributed_trainable():
     assert len(layer.trainable_weights) == 2


+@keras_test
+@pytest.mark.skipif((K.backend() == 'cntk'),
+                    reason='Unknown timesteps for RNN not supported in CNTK.')
+def test_TimeDistributed_with_masked_embedding_and_unspecified_shape():
+    # test with unspecified shape and Embeddings with mask_zero
+    model = Sequential()
+    model.add(wrappers.TimeDistributed(layers.Embedding(5, 6, mask_zero=True),
+                                       input_shape=(None, None)))  # N by t_1 by t_2 by 6
+    model.add(wrappers.TimeDistributed(layers.SimpleRNN(7, return_sequences=True)))
+    model.add(wrappers.TimeDistributed(layers.SimpleRNN(8, return_sequences=False)))
+    model.add(layers.SimpleRNN(1, return_sequences=False))
+    model.compile(optimizer='rmsprop', loss='mse')
+    model_input = np.random.randint(low=1, high=5, size=(10, 3, 4), dtype='int32')
+    for i in range(4):
+        model_input[i, i:, i:] = 0
+    model.fit(model_input,
+              np.random.random((10, 1)), epochs=1, batch_size=10)
+    mask_outputs = [model.layers[0].compute_mask(model.input)]
+    for layer in model.layers[1:]:
+        mask_outputs.append(layer.compute_mask(layer.input, mask_outputs[-1]))
+    func = K.function([model.input], mask_outputs[:-1])
+    mask_outputs_val = func([model_input])
+    ref_mask_val_0 = model_input > 0  # embedding layer
+    ref_mask_val_1 = ref_mask_val_0  # first RNN layer
+    ref_mask_val_2 = np.any(ref_mask_val_1, axis=-1)  # second RNN layer
+    ref_mask_val = [ref_mask_val_0, ref_mask_val_1, ref_mask_val_2]
+    for i in range(3):
+        assert np.array_equal(mask_outputs_val[i], ref_mask_val[i])
+    assert mask_outputs[-1] is None  # final layer


+@keras_test
+def test_TimeDistributed_with_masking_layer():
+    # test with Masking layer
+    model = Sequential()
+    model.add(wrappers.TimeDistributed(layers.Masking(mask_value=0.),
+                                       input_shape=(None, 4)))
+    model.add(wrappers.TimeDistributed(layers.Dense(5)))
+    model.compile(optimizer='rmsprop', loss='mse')
+    model_input = np.random.randint(low=1, high=5, size=(10, 3, 4))
+    for i in range(4):
+        model_input[i, i:, :] = 0.
+    model.fit(model_input,
+              np.random.random((10, 3, 5)), epochs=1, batch_size=6)
+    mask_outputs = [model.layers[0].compute_mask(model.input)]
+    mask_outputs += [model.layers[1].compute_mask(model.layers[1].input,
+                                                  mask_outputs[-1])]
+    func = K.function([model.input], mask_outputs)
+    mask_outputs_val = func([model_input])
+    assert np.array_equal(mask_outputs_val[0], np.any(model_input, axis=-1))
+    assert np.array_equal(mask_outputs_val[1], np.any(model_input, axis=-1))


 @keras_test
 def test_regularizers():
     model = Sequential()