提交 978efb61 编写于 作者: V vkk800 提交者: François Chollet

Fix undefined behaviour: preprocess_input copying/not copying the input arrays (#10153)

* Add copy option for image preprocessing

* Fix unnecessary import

* Fix style

* fix test error

* Make modifications in-place instead
上级 52854217
......@@ -38,7 +38,8 @@ def _preprocess_numpy_input(x, data_format, mode):
# Returns
Preprocessed Numpy array.
"""
x = x.astype(K.floatx())
if not issubclass(x.dtype.type, np.floating):
x = x.astype(K.floatx(), copy=False)
if mode == 'tf':
x /= 127.5
......@@ -153,6 +154,9 @@ def preprocess_input(x, data_format=None, mode='caffe'):
# Arguments
x: Input Numpy or symbolic tensor, 3D or 4D.
The preprocessed data is written over the input data
if the data types are compatible. To avoid this
behaviour, `numpy.copy(x)` can be used.
data_format: Data format of the image tensor/array.
mode: One of "caffe", "tf" or "torch".
- caffe: will convert the images from RGB to BGR,
......
......@@ -38,6 +38,22 @@ def test_preprocess_input():
assert_allclose(out1, out2.transpose(1, 2, 0))
assert_allclose(out1int, out2int.transpose(1, 2, 0))
# Test that writing over the input data works predictably
for mode in ['torch', 'tf']:
x = np.random.uniform(0, 255, (2, 10, 10, 3))
xint = x.astype('int')
x2 = utils.preprocess_input(x, mode=mode)
xint2 = utils.preprocess_input(xint)
assert_allclose(x, x2)
assert xint.astype('float').max() != xint2.max()
# Caffe mode works differently from the others
x = np.random.uniform(0, 255, (2, 10, 10, 3))
xint = x.astype('int')
x2 = utils.preprocess_input(x, data_format='channels_last', mode='caffe')
xint2 = utils.preprocess_input(xint)
assert_allclose(x, x2[..., ::-1])
assert xint.astype('float').max() != xint2.max()
def test_preprocess_input_symbolic():
# Test image batch
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册