diff --git a/keras/applications/imagenet_utils.py b/keras/applications/imagenet_utils.py index d6a3c23d7ab879566a440aa2c36b0706965d84ef..ab59cd8e2389170d2d34b5866b3b2d6a6c3d2ba9 100644 --- a/keras/applications/imagenet_utils.py +++ b/keras/applications/imagenet_utils.py @@ -38,7 +38,8 @@ def _preprocess_numpy_input(x, data_format, mode): # Returns Preprocessed Numpy array. """ - x = x.astype(K.floatx()) + if not issubclass(x.dtype.type, np.floating): + x = x.astype(K.floatx(), copy=False) if mode == 'tf': x /= 127.5 @@ -153,6 +154,9 @@ def preprocess_input(x, data_format=None, mode='caffe'): # Arguments x: Input Numpy or symbolic tensor, 3D or 4D. + The preprocessed data is written over the input data + if the data types are compatible. To avoid this + behaviour, `numpy.copy(x)` can be used. data_format: Data format of the image tensor/array. mode: One of "caffe", "tf" or "torch". - caffe: will convert the images from RGB to BGR, diff --git a/tests/keras/applications/imagenet_utils_test.py b/tests/keras/applications/imagenet_utils_test.py index 85674b0b92c7b649a7220d16ba0e82dd47658784..91481203a60731c7586edfe6fb330373ade2556c 100644 --- a/tests/keras/applications/imagenet_utils_test.py +++ b/tests/keras/applications/imagenet_utils_test.py @@ -38,6 +38,22 @@ def test_preprocess_input(): assert_allclose(out1, out2.transpose(1, 2, 0)) assert_allclose(out1int, out2int.transpose(1, 2, 0)) + # Test that writing over the input data works predictably + for mode in ['torch', 'tf']: + x = np.random.uniform(0, 255, (2, 10, 10, 3)) + xint = x.astype('int') + x2 = utils.preprocess_input(x, mode=mode) + xint2 = utils.preprocess_input(xint) + assert_allclose(x, x2) + assert xint.astype('float').max() != xint2.max() + # Caffe mode works differently from the others + x = np.random.uniform(0, 255, (2, 10, 10, 3)) + xint = x.astype('int') + x2 = utils.preprocess_input(x, data_format='channels_last', mode='caffe') + xint2 = utils.preprocess_input(xint) + assert_allclose(x, x2[..., ::-1]) + assert xint.astype('float').max() != xint2.max() + def test_preprocess_input_symbolic(): # Test image batch