提交 da9ce7d8 编写于 作者: B Brian Nemsick 提交者: François Chollet

Fix Stateful Metrics in fit_generator with TensorBoard (#10673)

* Fix casting validate generator with stateful metrics.

* Add a stateful metric to Tensboard tests
上级 1c4bb1ed
......@@ -375,7 +375,7 @@ def evaluate_generator(model, generator,
averages.append(np.average([out[i] for out in outs_per_batch],
weights=batch_sizes))
else:
averages.append(float(outs_per_batch[-1][i]))
averages.append(np.float64(outs_per_batch[-1][i]))
return unpack_singleton(averages)
......
......@@ -10,7 +10,7 @@ from keras import optimizers
from keras import initializers
from keras import callbacks
from keras.models import Sequential, Model
from keras.layers import Input, Dense, Dropout, add, dot, Lambda
from keras.layers import Input, Dense, Dropout, add, dot, Lambda, Layer
from keras.layers.convolutional import Conv2D
from keras.layers.pooling import MaxPooling2D, GlobalAveragePooling1D, GlobalAveragePooling2D
from keras.utils.test_utils import get_test_data
......@@ -500,6 +500,19 @@ def test_TensorBoard(tmpdir):
i += 1
i = i % max_batch_index
class DummyStatefulMetric(Layer):
def __init__(self, name='dummy_stateful_metric', **kwargs):
super(DummyStatefulMetric, self).__init__(name=name, **kwargs)
self.stateful = True
self.state = K.variable(value=0, dtype='int32')
def reset_states(self):
pass
def __call__(self, y_true, y_pred):
return self.state
inp = Input((input_dim,))
hidden = Dense(num_hidden, activation='relu')(inp)
hidden = Dropout(0.1)(hidden)
......@@ -507,7 +520,7 @@ def test_TensorBoard(tmpdir):
model = Model(inputs=inp, outputs=output)
model.compile(loss='categorical_crossentropy',
optimizer='sgd',
metrics=['accuracy'])
metrics=['accuracy', DummyStatefulMetric()])
# we must generate new callbacks for each test, as they aren't stateless
def callbacks_factory(histogram_freq, embeddings_freq=1):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册