提交 38e38315 编写于 作者: F François Chollet

Add metrics CosineSimilarity, MeanAbsoluteError, MeanAbsolutePercentageError, MeanSquaredError, MeanSquaredLogarithmicError, RootMeanSquaredError.

Add metrics CosineSimilarity, MeanAbsoluteError, MeanAbsolutePercentageError, MeanSquaredError, MeanSquaredLogarithmicError, RootMeanSquaredError.
上级 680be2e1
......@@ -718,10 +718,10 @@ def poisson(y_true, y_pred):
return K.mean(y_pred - y_true * K.log(y_pred + K.epsilon()), axis=-1)
def cosine_proximity(y_true, y_pred, axis=-1):
    """Computes the negative cosine similarity between labels and predictions.

    Both inputs are L2-normalized along `axis`, so the result is the negated
    cosine of the angle between them, in [-1, 1] (lower is better when used
    as a loss).

    # Arguments
        y_true: Tensor of ground-truth values.
        y_pred: Tensor of predicted values.
        axis: (Optional) Defaults to -1. The dimension along which the
            cosine proximity is computed.

    # Returns
        Tensor with the negative cosine similarity, reduced along `axis`.
    """
    y_true = K.l2_normalize(y_true, axis=axis)
    y_pred = K.l2_normalize(y_pred, axis=axis)
    return -K.sum(y_true * y_pred, axis=axis)
def _maybe_convert_labels(y_true):
# Short-name aliases so losses can be referenced by their common
# abbreviations (e.g. `loss='mape'`) in `compile()`.
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
kld = KLD = kullback_leibler_divergence
# `cosine_similarity` is the preferred name; `cosine` is kept for
# backward compatibility.
cosine = cosine_similarity = cosine_proximity
def is_categorical_crossentropy(loss):
......
......@@ -24,6 +24,7 @@ from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .losses import cosine_similarity
from .utils import losses_utils
from .utils import metrics_utils
from .utils.generic_utils import deserialize_keras_object
......@@ -670,6 +671,150 @@ class KLDivergence(MeanMetricWrapper):
kullback_leibler_divergence, name, dtype=dtype)
class CosineSimilarity(MeanMetricWrapper):
    """Streaming mean of the cosine similarity between labels and predictions.

    cosine similarity = (a . b) / ||a|| ||b||
    [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity)

    For example, if `y_true` is [0, 1, 1] and `y_pred` is [1, 0, 1], the
    cosine similarity is 0.5.

    The metric keeps a running average of the cosine similarity between
    `predictions` and `labels` over a stream of data.

    # Arguments
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
        axis: (Optional) Defaults to -1. The dimension along which the cosine
            similarity is computed.

    Usage with the `compile` API:

    ```python
    model = keras.Model(inputs, outputs)
    model.compile(
        'sgd',
        loss='mse',
        metrics=[keras.metrics.CosineSimilarity(axis=1)])
    ```
    """

    def __init__(self, name='cosine_similarity', dtype=None, axis=-1):
        # Delegates the per-batch computation to the `cosine_similarity`
        # loss function; `MeanMetricWrapper` handles the running average.
        super(CosineSimilarity, self).__init__(
            cosine_similarity, name, dtype=dtype, axis=axis)
class MeanAbsoluteError(MeanMetricWrapper):
    """Streaming mean absolute error between the labels and predictions.

    # Arguments
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Usage with the `compile` API:

    ```python
    model = keras.Model(inputs, outputs)
    model.compile('sgd', metrics=[keras.metrics.MeanAbsoluteError()])
    ```
    """

    def __init__(self, name='mean_absolute_error', dtype=None):
        # Per-batch values come from the `mean_absolute_error` loss;
        # the wrapper maintains the running mean.
        super(MeanAbsoluteError, self).__init__(
            mean_absolute_error, name, dtype=dtype)
class MeanAbsolutePercentageError(MeanMetricWrapper):
    """Streaming mean absolute percentage error between `y_true` and `y_pred`.

    For example, if `y_true` is [0., 0., 1., 1.], and `y_pred` is [1., 1., 1., 0.]
    the mean absolute percentage error is 5e+08.

    # Arguments
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Usage:

    ```python
    m = keras.metrics.MeanAbsolutePercentageError()
    m.update_state([0., 0., 1., 1.], [1., 1., 1., 0.])
    # Final result: 5e+08
    ```

    Usage with the `compile` API:

    ```python
    model = keras.Model(inputs, outputs)
    model.compile('sgd', metrics=[keras.metrics.MeanAbsolutePercentageError()])
    ```
    """

    def __init__(self, name='mean_absolute_percentage_error', dtype=None):
        # Per-batch values come from the `mean_absolute_percentage_error`
        # loss; the wrapper maintains the running mean.
        super(MeanAbsolutePercentageError, self).__init__(
            mean_absolute_percentage_error, name, dtype=dtype)
class MeanSquaredError(MeanMetricWrapper):
    """Streaming mean squared error between `y_true` and `y_pred`.

    # Arguments
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Usage with the `compile` API:

    ```python
    model = keras.Model(inputs, outputs)
    model.compile('sgd', metrics=[keras.metrics.MeanSquaredError()])
    ```
    """

    def __init__(self, name='mean_squared_error', dtype=None):
        # Per-batch values come from the `mean_squared_error` loss;
        # the wrapper maintains the running mean.
        super(MeanSquaredError, self).__init__(
            mean_squared_error, name, dtype=dtype)
class MeanSquaredLogarithmicError(MeanMetricWrapper):
    """Streaming mean squared logarithmic error between `y_true` and `y_pred`.

    # Arguments
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Usage with the `compile` API:

    ```python
    model = keras.Model(inputs, outputs)
    model.compile('sgd', metrics=[keras.metrics.MeanSquaredLogarithmicError()])
    ```
    """

    def __init__(self, name='mean_squared_logarithmic_error', dtype=None):
        # Per-batch values come from the `mean_squared_logarithmic_error`
        # loss; the wrapper maintains the running mean.
        super(MeanSquaredLogarithmicError, self).__init__(
            mean_squared_logarithmic_error, name, dtype=dtype)
class RootMeanSquaredError(Mean):
    """Computes root mean squared error metric between `y_true` and `y_pred`.

    Squared errors are accumulated through the `Mean` base class; `result()`
    reports the square root of that running mean.

    # Arguments
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Usage with the `compile` API:

    ```python
    model = keras.Model(inputs, outputs)
    model.compile('sgd', metrics=[keras.metrics.RootMeanSquaredError()])
    ```
    """

    def __init__(self, name='root_mean_squared_error', dtype=None):
        super(RootMeanSquaredError, self).__init__(name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates root mean squared error statistics.

        # Arguments
            y_true: The ground truth values.
            y_pred: The predicted values.
            sample_weight: Optional weighting of each example. Defaults to 1.
                Can be a `Tensor` whose rank is either 0, or the same rank
                as `y_true`, and must be broadcastable to `y_true`.

        # Returns
            List of update ops.
        """
        # Feed the element-wise squared error into the running mean kept
        # by the `Mean` base class.
        squared_error = K.square(y_pred - y_true)
        return super(RootMeanSquaredError, self).update_state(
            squared_error, sample_weight=sample_weight)

    def result(self):
        # Square root of the running (weighted) mean of squared errors.
        return K.sqrt(self.total / self.count)
class BinaryCrossentropy(MeanMetricWrapper):
"""Computes the crossentropy metric between the labels and predictions.
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册