From cb2814b4983fb96a684f7efd7f44cb69f6a1591a Mon Sep 17 00:00:00 2001 From: jiangzhiwen Date: Mon, 11 May 2020 16:47:54 +0800 Subject: [PATCH] flat_map first commit --- mindspore/dataset/engine/datasets.py | 44 ++++++++++++ .../dataset/test_flat_map/image_index.txt | 2 + .../ut/data/dataset/test_flat_map/images.txt | 3 + .../ut/data/dataset/test_flat_map/images1.txt | 3 + .../ut/data/dataset/test_flat_map/images2.txt | 3 + tests/ut/python/dataset/test_flat_map.py | 72 +++++++++++++++++++ 6 files changed, 127 insertions(+) create mode 100644 tests/ut/data/dataset/test_flat_map/image_index.txt create mode 100644 tests/ut/data/dataset/test_flat_map/images.txt create mode 100644 tests/ut/data/dataset/test_flat_map/images1.txt create mode 100644 tests/ut/data/dataset/test_flat_map/images2.txt create mode 100644 tests/ut/python/dataset/test_flat_map.py diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 3d6394980..bcf4085a6 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -268,6 +268,50 @@ class Dataset: """ return ShuffleDataset(self, buffer_size) + def flat_map(self, func): + """ + Maps `func` to each row in dataset and flatten the result. + + The specified `func` is a function that must take one 'Ndarray' as input + and return a 'Dataset'. + + Args: + func (function): A function that must take one 'Ndarray' as an argument and + return a 'Dataset'. + + Returns: + Dataset, applied by the function. + + Examples: + >>> import mindspore.dataset as ds + >>> import mindspore.dataset.transforms.nlp.utils as nlp + >>> # declare a function which returns a Dataset object + >>> def flat_map_func(x): + >>> data_dir = nlp.as_text(x[0]) + >>> d = ds.ImageFolderDatasetV2(data_dir) + >>> return d + >>> # data is a Dataset object + >>> data = ds.TextFileDataset(DATA_FILE) + >>> data = data.flat_map(flat_map_func) + + Raises: + TypeError: If `func` is not a function. + TypeError: If `func` doesn't return a Dataset. + """ + dataset = None + if not hasattr(func, '__call__'): + raise TypeError("func must be a function.") + + for row_data in self: + if dataset is None: + dataset = func(row_data) + else: + dataset += func(row_data) + + if not isinstance(dataset, Dataset): + raise TypeError("flat_map must return a Dataset object.") + return dataset + @check_map def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None, num_parallel_workers=None, python_multiprocessing=False): diff --git a/tests/ut/data/dataset/test_flat_map/image_index.txt b/tests/ut/data/dataset/test_flat_map/image_index.txt new file mode 100644 index 000000000..dd43bd128 --- /dev/null +++ b/tests/ut/data/dataset/test_flat_map/image_index.txt @@ -0,0 +1,2 @@ +../data/dataset/test_flat_map/images1.txt +../data/dataset/test_flat_map/images2.txt \ No newline at end of file diff --git a/tests/ut/data/dataset/test_flat_map/images.txt b/tests/ut/data/dataset/test_flat_map/images.txt new file mode 100644 index 000000000..0aca4507d --- /dev/null +++ b/tests/ut/data/dataset/test_flat_map/images.txt @@ -0,0 +1,3 @@ +../data/dataset/testPK/data +../data/dataset/testImageNetData/train +../data/dataset/testImageNetData2/train \ No newline at end of file diff --git a/tests/ut/data/dataset/test_flat_map/images1.txt b/tests/ut/data/dataset/test_flat_map/images1.txt new file mode 100644 index 000000000..0aca4507d --- /dev/null +++ b/tests/ut/data/dataset/test_flat_map/images1.txt @@ -0,0 +1,3 @@ +../data/dataset/testPK/data +../data/dataset/testImageNetData/train +../data/dataset/testImageNetData2/train \ No newline at end of file diff --git a/tests/ut/data/dataset/test_flat_map/images2.txt b/tests/ut/data/dataset/test_flat_map/images2.txt new file mode 100644 index 000000000..0aca4507d --- /dev/null +++ b/tests/ut/data/dataset/test_flat_map/images2.txt @@ -0,0 +1,3 @@ +../data/dataset/testPK/data +../data/dataset/testImageNetData/train +../data/dataset/testImageNetData2/train \ No newline at end of file diff --git a/tests/ut/python/dataset/test_flat_map.py b/tests/ut/python/dataset/test_flat_map.py new file mode 100644 index 000000000..5790ab7e8 --- /dev/null +++ b/tests/ut/python/dataset/test_flat_map.py @@ -0,0 +1,72 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import numpy as np +import mindspore.dataset as ds + +DATA_FILE = "../data/dataset/test_flat_map/images1.txt" +INDEX_FILE = "../data/dataset/test_flat_map/image_index.txt" + + +def test_flat_map_1(): + ''' + DATA_FILE records the path of image folders, load the images from them. + ''' + import mindspore.dataset.transforms.nlp.utils as nlp + + def flat_map_func(x): + data_dir = nlp.as_text(x[0]) + d = ds.ImageFolderDatasetV2(data_dir) + return d + + data = ds.TextFileDataset(DATA_FILE) + data = data.flat_map(flat_map_func) + + count = 0 + for d in data: + assert isinstance(d[0], np.ndarray) + count += 1 + assert count == 52 + + +def test_flat_map_2(): + ''' + Flatten 3D structure data + ''' + import mindspore.dataset.transforms.nlp.utils as nlp + + def flat_map_func_1(x): + data_dir = nlp.as_text(x[0]) + d = ds.ImageFolderDatasetV2(data_dir) + return d + + def flat_map_func_2(x): + text_file = nlp.as_text(x[0]) + d = ds.TextFileDataset(text_file) + d = d.flat_map(flat_map_func_1) + return d + + data = ds.TextFileDataset(INDEX_FILE) + data = data.flat_map(flat_map_func_2) + + count = 0 + for d in data: + assert isinstance(d[0], np.ndarray) + count += 1 + assert count == 104 + + +if __name__ == "__main__": + test_flat_map_1() + test_flat_map_2() -- GitLab