提交 1c4bb1ed 编写于 作者: K Kevin Mader 提交者: François Chollet

Fixing HDF5Matrix `.dtype` and `.shape` properties (#10749)

* Fixing HDF5Matrix `.dtype` and `.shape` output

Adding code to the initialization function of `HDF5Matrix` so that `.dtype` and `.shape` reflect the actual values returned by the normalizer

* Update io_utils_test.py

adding tests for normalizer functions that change dtype and shape

* Update io_utils_test.py

removing unnecessary whitespace to make pep8 happy
上级 d4c1f696
......@@ -58,6 +58,12 @@ class HDF5Matrix(object):
else:
self.end = end
self.normalizer = normalizer
if self.normalizer is not None:
first_val = self.normalizer(self.data[0:1])
else:
first_val = self.data[0:1]
self._base_shape = first_val.shape[1:]
self._base_dtype = first_val.dtype
def __len__(self):
return self.end - self.start
......@@ -101,7 +107,7 @@ class HDF5Matrix(object):
# Returns
A numpy-style shape tuple.
"""
return (self.end - self.start,) + self.data.shape[1:]
return (self.end - self.start,) + self._base_shape
@property
def dtype(self):
......@@ -110,7 +116,7 @@ class HDF5Matrix(object):
# Returns
A numpy dtype string.
"""
return self.data.dtype
return self._base_dtype
@property
def ndim(self):
......
......@@ -104,6 +104,16 @@ def test_io_utils(in_tmpdir):
normalized_X_train = HDF5Matrix(h5_path, 'my_data', start=0, end=150, normalizer=normalizer)
assert np.isclose(normalized_X_train[0][0], X_train[0][0] + 1)
# test resizing normalizer
normalizer_rs = lambda x: x[:, ::2]
normalized_rs_X_train = HDF5Matrix(h5_path, 'my_data', start=0, end=150, normalizer=normalizer_rs)
assert (normalized_rs_X_train.shape[1] == 5)
# test dtype changing normalizer
normalizer_dtype = lambda x: x.astype(np.uint8)
normalized_dtype_X_train = HDF5Matrix(h5_path, 'my_data', start=0, end=150, normalizer=normalizer_dtype)
assert (normalized_dtype_X_train.dtype == np.uint8)
os.remove(h5_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册