Commit 25d0193e authored by tiferet, committed by François Chollet

Supporting channels_first data format with crossentropy losses (#9715)

Parent: 1e80c1a2
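Editor's note: the diff below threads a new `axis` argument through the backend crossentropy functions so that predictions stored as `channels_first` (classes on axis 1) are reduced over the right axis. A minimal usage sketch, assuming a Keras build that includes this commit; shapes and values are illustrative only:

```python
import numpy as np
from keras import backend as K

# Predictions in channels_first layout: (batch, classes, height, width).
preds = K.constant(np.random.uniform(size=(2, 4, 3, 3)))
# One-hot targets with the class axis in the same position.
labels = np.eye(4)[np.random.randint(4, size=(2, 3, 3))]  # (2, 3, 3, 4)
targets = K.constant(np.moveaxis(labels, -1, 1))           # (2, 4, 3, 3)

# axis=1 names the class axis (channels_first); the default axis=-1
# preserves the old channels_last behaviour.
loss = K.categorical_crossentropy(targets, preds, axis=1)
print(K.eval(loss).shape)  # per-position loss: (2, 3, 3)
```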
keras/backend/cntk_backend.py
@@ -1823,7 +1823,23 @@ def softsign(x):
return x / (1 + C.abs(x))
def categorical_crossentropy(target, output, from_logits=False):
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
# Here, unlike other backends, the tensors lack a batch dimension:
axis_without_batch = -1 if axis == -1 else axis - 1
output_dimensions = list(range(len(output.shape)))
if axis_without_batch != -1 and axis_without_batch not in output_dimensions:
raise ValueError(
'{}{}{}'.format(
'Unexpected channels axis {}. '.format(axis_without_batch),
'Expected to be -1 or one of the axes of `output`, ',
'which has {} dimensions.'.format(len(output.shape))))
# If the channels are not in the last axis, move them to be there:
if axis_without_batch != -1 and axis_without_batch != output_dimensions[-1]:
permutation = output_dimensions[:axis_without_batch]
permutation += output_dimensions[axis_without_batch + 1:]
permutation += [axis_without_batch]
output = C.transpose(output, permutation)
target = C.transpose(target, permutation)
if from_logits:
result = C.cross_entropy_with_softmax(output, target)
# cntk's result shape is (batch, 1), while keras expects (batch,)
@@ -1836,10 +1852,20 @@ def categorical_crossentropy(target, output, from_logits=False):
return -sum(target * C.log(output), axis=-1)
def sparse_categorical_crossentropy(target, output, from_logits=False):
target = C.one_hot(target, output.shape[-1])
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
# Here, unlike other backends, the tensors lack a batch dimension:
axis_without_batch = -1 if axis == -1 else axis - 1
output_dimensions = list(range(len(output.shape)))
if axis_without_batch != -1 and axis_without_batch not in output_dimensions:
raise ValueError(
'{}{}{}'.format(
'Unexpected channels axis {}. '.format(axis_without_batch),
'Expected to be -1 or one of the axes of `output`, ',
'which has {} dimensions.'.format(len(output.shape))))
target = C.one_hot(target, output.shape[axis_without_batch],
axis=axis_without_batch)
target = C.reshape(target, output.shape)
return categorical_crossentropy(target, output, from_logits)
return categorical_crossentropy(target, output, from_logits, axis=axis)
class Function(object):
......
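Editor's note: CNTK tensors in Keras carry no static batch axis, so the user-facing `axis` (which counts the batch) is shifted down by one before indexing `output.shape`. A plain-Python sketch of that bookkeeping and of the permutation built above; no CNTK is needed, and `cntk_axis` is a hypothetical helper name:

```python
def cntk_axis(axis):
    # The Keras-level axis counts the batch dimension; CNTK shapes do not.
    return -1 if axis == -1 else axis - 1

shape = (4, 3, 3)                  # (classes, H, W): batch axis absent
axis_without_batch = cntk_axis(1)  # channels_first -> 0

dims = list(range(len(shape)))
# Move the class axis to the end, keeping the other axes in order:
permutation = (dims[:axis_without_batch]
               + dims[axis_without_batch + 1:]
               + [axis_without_batch])
print(permutation)  # [1, 2, 0]
```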
keras/backend/tensorflow_backend.py
@@ -3173,7 +3173,7 @@ def softsign(x):
return tf.nn.softsign(x)
def categorical_crossentropy(target, output, from_logits=False):
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
"""Categorical crossentropy between an output tensor and a target tensor.
# Arguments
@@ -3183,28 +3183,40 @@ def categorical_crossentropy(target, output, from_logits=False):
case `output` is expected to be the logits).
from_logits: Boolean, whether `output` is the
result of a softmax, or is a tensor of logits.
axis: Int specifying the channels axis. `axis=-1`
corresponds to data format `channels_last`,
and `axis=1` corresponds to data format
`channels_first`.
# Returns
Output tensor.
"""
# Raises
ValueError: if `axis` is neither -1 nor one of
the axes of `output`.
"""
output_dimensions = list(range(len(output.get_shape())))
if axis != -1 and axis not in output_dimensions:
raise ValueError(
'{}{}{}'.format(
'Unexpected channels axis {}. '.format(axis),
'Expected to be -1 or one of the axes of `output`, ',
'which has {} dimensions.'.format(len(output.get_shape()))))
# Note: tf.nn.softmax_cross_entropy_with_logits
# expects logits, Keras expects probabilities.
if not from_logits:
# scale preds so that the class probas of each sample sum to 1
output /= tf.reduce_sum(output,
len(output.get_shape()) - 1,
True)
output /= tf.reduce_sum(output, axis, True)
# manual computation of crossentropy
_epsilon = _to_tensor(epsilon(), output.dtype.base_dtype)
output = tf.clip_by_value(output, _epsilon, 1. - _epsilon)
return - tf.reduce_sum(target * tf.log(output),
len(output.get_shape()) - 1)
return - tf.reduce_sum(target * tf.log(output), axis)
else:
return tf.nn.softmax_cross_entropy_with_logits(labels=target,
logits=output)
def sparse_categorical_crossentropy(target, output, from_logits=False):
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
"""Categorical crossentropy with integer targets.
# Arguments
@@ -3214,10 +3226,31 @@ def sparse_categorical_crossentropy(target, output, from_logits=False):
case `output` is expected to be the logits).
from_logits: Boolean, whether `output` is the
result of a softmax, or is a tensor of logits.
axis: Int specifying the channels axis. `axis=-1`
corresponds to data format `channels_last`,
and `axis=1` corresponds to data format
`channels_first`.
# Returns
Output tensor.
"""
# Raises
ValueError: if `axis` is neither -1 nor one of
the axes of `output`.
"""
output_dimensions = list(range(len(output.get_shape())))
if axis != -1 and axis not in output_dimensions:
raise ValueError(
'{}{}{}'.format(
'Unexpected channels axis {}. '.format(axis),
'Expected to be -1 or one of the axes of `output`, ',
'which has {} dimensions.'.format(len(output.get_shape()))))
# If the channels are not in the last axis, move them to be there:
if axis != -1 and axis != output_dimensions[-1]:
permutation = output_dimensions[:axis] + output_dimensions[axis + 1:]
permutation += [axis]
output = tf.transpose(output, perm=permutation)
# Note: tf.nn.sparse_softmax_cross_entropy_with_logits
# expects logits, Keras expects probabilities.
if not from_logits:
......
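Editor's note: in the sparse variant only `output` is transposed, since integer targets have no class axis to move. A NumPy stand-in for the permutation built above (shapes are illustrative):

```python
import numpy as np

output = np.random.uniform(size=(2, 4, 3, 3))  # (batch, classes, H, W)
target = np.random.randint(4, size=(2, 3, 3))  # integer labels: no class axis

axis = 1
dims = list(range(output.ndim))
perm = dims[:axis] + dims[axis + 1:] + [axis]  # [0, 2, 3, 1]
output = np.transpose(output, perm)            # -> (batch, H, W, classes)
assert output.shape[:-1] == target.shape       # labels align with positions
```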
keras/backend/theano_backend.py
@@ -1619,7 +1619,20 @@ def softsign(x):
return T_softsign(x)
def categorical_crossentropy(target, output, from_logits=False):
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
output_dimensions = list(range(len(int_shape(output))))
if axis != -1 and axis not in output_dimensions:
raise ValueError(
'{}{}{}'.format(
'Unexpected channels axis {}. '.format(axis),
'Expected to be -1 or one of the axes of `output`, ',
'which has {} dimensions.'.format(len(int_shape(output)))))
# If the channels are not in the last axis, move them to be there:
if axis != -1 and axis != output_dimensions[-1]:
permutation = output_dimensions[:axis]
permutation += output_dimensions[axis + 1:] + [axis]
output = permute_dimensions(output, permutation)
target = permute_dimensions(target, permutation)
if from_logits:
output = T.nnet.softmax(output)
else:
@@ -1630,11 +1643,24 @@ def categorical_crossentropy(target, output, from_logits=False):
return T.nnet.categorical_crossentropy(output, target)
def sparse_categorical_crossentropy(target, output, from_logits=False):
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
output_dimensions = list(range(len(int_shape(output))))
if axis != -1 and axis not in output_dimensions:
raise ValueError(
'{}{}{}'.format(
'Unexpected channels axis {}. '.format(axis),
'Expected to be -1 or one of the axes of `output`, ',
'which has {} dimensions.'.format(len(int_shape(output)))))
# If the channels are not in the last axis, move them to be there:
if axis != -1 and axis != output_dimensions[-1]:
permutation = output_dimensions[:axis]
permutation += output_dimensions[axis + 1:] + [axis]
output = permute_dimensions(output, permutation)
target = permute_dimensions(target, permutation)
target = T.cast(T.flatten(target), 'int32')
target = T.extra_ops.to_one_hot(target, nb_class=output.shape[-1])
target = reshape(target, shape(output))
return categorical_crossentropy(target, output, from_logits)
return categorical_crossentropy(target, output, from_logits, axis=-1)
def binary_crossentropy(target, output, from_logits=False):
......
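Editor's note: after the permutation the classes sit on the last axis, which is why the recursive call above can hard-code `axis=-1`. A NumPy sketch of the flatten -> one-hot -> reshape round trip the Theano code relies on:

```python
import numpy as np

target = np.array([[0, 2], [1, 3]])  # integer labels, shape (2, 2)
nb_class = 4

flat = target.reshape(-1)                              # (4,)
one_hot = np.eye(nb_class)[flat]                       # (4, 4)
one_hot = one_hot.reshape(target.shape + (nb_class,))  # (2, 2, 4): classes last
```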
keras/engine/training.py
@@ -767,7 +767,12 @@ class Model(Network):
for output_shape, loss_fn in zip(self._feed_output_shapes,
self._feed_loss_fns):
if loss_fn is losses.sparse_categorical_crossentropy:
feed_output_shapes.append(output_shape[:-1] + (1,))
if K.image_data_format() == 'channels_first' and len(
output_shape) in [4, 5]:
feed_output_shapes.append(
(output_shape[0], 1) + output_shape[2:])
else:
feed_output_shapes.append(output_shape[:-1] + (1,))
elif (not hasattr(loss_fn, '__name__') or
getattr(losses, loss_fn.__name__, None) is None):
# If `loss_fn` is not a function (e.g. callable class)
......
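Editor's note: for `sparse_categorical_crossentropy` the expected target replaces the class axis with a single integer-label channel, and that channel's position depends on the data format (4D/5D covers the 2D- and 3D-convolutional cases). A sketch of the shape rule applied above, with an illustrative `output_shape`:

```python
from keras import backend as K

output_shape = (None, 4, 3, 3)  # channels_first prediction: (batch, classes, H, W)

if K.image_data_format() == 'channels_first' and len(output_shape) in [4, 5]:
    target_shape = (output_shape[0], 1) + output_shape[2:]  # (None, 1, 3, 3)
else:
    target_shape = output_shape[:-1] + (1,)                 # channels_last rule
```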
keras/utils/np_utils.py
@@ -17,7 +17,8 @@ def to_categorical(y, num_classes=None):
num_classes: total number of classes.
# Returns
A binary matrix representation of the input.
A binary matrix representation of the input. The classes axis
is placed last.
"""
y = np.array(y, dtype='int')
input_shape = y.shape
......
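Editor's note: `to_categorical` always appends the class axis last, so `channels_first` users move it themselves afterwards. An illustrative sketch, assuming the multi-dimensional label support shown in this file:

```python
import numpy as np
from keras.utils import to_categorical

labels = np.random.randint(4, size=(2, 3, 3))         # integer labels
one_hot_last = to_categorical(labels, num_classes=4)  # (2, 3, 3, 4): classes last
one_hot_first = np.moveaxis(one_hot_last, -1, 1)      # (2, 4, 3, 3)
```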
tests/keras/engine/test_training.py
@@ -7,7 +7,7 @@ import scipy.sparse as sparse
import keras
from keras import losses
from keras.layers import Dense, Dropout
from keras.layers import Dense, Dropout, Conv2D
from keras.engine import Input
from keras.engine.training import Model
from keras.engine import training_utils
@@ -1361,5 +1361,87 @@ def test_training_and_eval_methods_on_symbolic_tensors_multi_io():
model.test_on_batch([input_a_tf, input_b_tf], [output_d_tf, output_e_tf])
@keras_test
def test_model_with_crossentropy_losses_channels_first():
"""Tests use of all crossentropy losses with `channels_first`.
Tests `sparse_categorical_crossentropy`, `categorical_crossentropy`,
and `binary_crossentropy`.
Verifies that evaluate gives the same result with either
`channels_first` or `channels_last` image_data_format.
Tests PR #9715.
"""
def prepare_simple_model(input_tensor, loss_name, target):
axis = 1 if K.image_data_format() == 'channels_first' else -1
if loss_name == 'sparse_categorical_crossentropy':
loss = lambda y_true, y_pred: K.sparse_categorical_crossentropy(
y_true, y_pred, axis=axis)
num_channels = np.amax(target) + 1
activation = 'softmax'
elif loss_name == 'categorical_crossentropy':
loss = lambda y_true, y_pred: K.categorical_crossentropy(
y_true, y_pred, axis=axis)
num_channels = target.shape[axis]
activation = 'softmax'
elif loss_name == 'binary_crossentropy':
loss = lambda y_true, y_pred: K.binary_crossentropy(y_true, y_pred)
num_channels = target.shape[axis]
activation = 'sigmoid'
predictions = Conv2D(num_channels, 1, activation=activation,
kernel_initializer='ones',
bias_initializer='ones')(input_tensor)
simple_model = Model(inputs=input_tensor, outputs=predictions)
simple_model.compile(optimizer='rmsprop', loss=loss)
return simple_model
losses_to_test = ['sparse_categorical_crossentropy',
'categorical_crossentropy', 'binary_crossentropy']
data_channels_first = np.array([[[[8., 7.1, 0.], [4.5, 2.6, 0.55],
[0.9, 4.2, 11.2]]]], dtype=np.float32)
# Labels for testing 4-class sparse_categorical_crossentropy, 4-class
# categorical_crossentropy, and 2-class binary_crossentropy:
labels_channels_first = [np.array([[[[0, 1, 3], [2, 1, 0], [2, 2, 1]]]]),
np.array([[[[0, 1, 0], [0, 1, 0], [0, 0, 0]],
[[1, 0, 0], [0, 0, 1], [0, 1, 0]],
[[0, 0, 0], [1, 0, 0], [0, 0, 1]],
[[0, 0, 1], [0, 0, 0], [1, 0, 0]]]]),
np.array([[[[0, 1, 0], [0, 1, 0], [0, 0, 1]],
[[1, 0, 1], [1, 0, 1], [1, 1, 0]]]])]
# Compute one loss for each loss function in the list `losses_to_test`:
loss_channels_last = [0., 0., 0.]
loss_channels_first = [0., 0., 0.]
old_data_format = K.image_data_format()
# Evaluate a simple network with channels last, with all three loss
# functions:
K.set_image_data_format('channels_last')
data = np.moveaxis(data_channels_first, 1, -1)
for index, loss_function in enumerate(losses_to_test):
labels = np.moveaxis(labels_channels_first[index], 1, -1)
inputs = Input(shape=(3, 3, 1))
model = prepare_simple_model(inputs, loss_function, labels)
loss_channels_last[index] = model.evaluate(x=data, y=labels,
batch_size=1, verbose=0)
# Evaluate the same network with channels first, with all three loss
# functions:
K.set_image_data_format('channels_first')
data = data_channels_first
for index, loss_function in enumerate(losses_to_test):
labels = labels_channels_first[index]
inputs = Input(shape=(1, 3, 3))
model = prepare_simple_model(inputs, loss_function, labels)
loss_channels_first[index] = model.evaluate(x=data, y=labels,
batch_size=1, verbose=0)
K.set_image_data_format(old_data_format)
assert_allclose(loss_channels_first, loss_channels_last,
err_msg='{}{}'.format('Computed different losses for ',
'channels_first and channels_last.'))
if __name__ == '__main__':
pytest.main([__file__])