diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings.cc index c578c295e5a26d1b43663c5e1160ac2a9b294031..3d79f3561d821e4caff893dbb536a90a7213bed5 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings.cc @@ -44,7 +44,14 @@ PYBIND_REGISTER( [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); }) .def("SetBatchParameters", [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); }) - .def("LaunchTreeExec", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.LaunchTreeExec(num_epochs)); }) + .def("PrepareTree", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.PrepareTree(num_epochs)); }) + .def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); }) + .def("GetColumnNames", + [](DEPipeline &de) { + py::list out; + THROW_IF_ERROR(de.GetColumnNames(&out)); + return out; + }) .def("GetNextAsMap", [](DEPipeline &de) { py::dict out; diff --git a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc index 443c669f4a7675d0abe3ccc47837447a0b42b241..57207f0676b344313491a8b0f914581aab9a2af7 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc @@ -172,9 +172,11 @@ Status DEPipeline::AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr & // Function to assign the node as root. Status DEPipeline::AssignRootNode(const DsOpPtr &dataset_op) { return (tree_->AssignRoot(dataset_op)); } +// Function to prepare the tree +Status DEPipeline::PrepareTree(const int32_t num_epochs) { return tree_->Prepare(num_epochs); } + // Function to launch the tree execution. 
-Status DEPipeline::LaunchTreeExec(const int32_t num_epochs) {
-  RETURN_IF_NOT_OK(tree_->Prepare(num_epochs));
+Status DEPipeline::LaunchTreeExec() {
   RETURN_IF_NOT_OK(tree_->Launch());
   iterator_ = std::make_unique<DatasetIterator>(tree_);
   if (iterator_ == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create an Iterator.");
@@ -189,6 +191,25 @@ void DEPipeline::PrintTree() {
   }
 }
 
+Status DEPipeline::GetColumnNames(py::list *output) {
+  if (!tree_->isPrepared()) {
+    RETURN_STATUS_UNEXPECTED("GetColumnNames: Make sure to call prepare before calling GetColumnNames.");
+  }
+  std::unordered_map<std::string, int32_t> column_name_id_map = tree_->root()->column_name_id_map();
+  if (column_name_id_map.empty())
+    RETURN_STATUS_UNEXPECTED("GetColumnNames: Column names was empty. Make sure Prepare is called.");
+  std::vector<std::pair<std::string, int32_t>> column_name_id_vector(column_name_id_map.begin(),
+                                                                     column_name_id_map.end());
+  std::sort(column_name_id_vector.begin(), column_name_id_vector.end(),
+            [](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) {
+              return a.second < b.second;
+            });
+  for (auto item : column_name_id_vector) {
+    (*output).append(item.first);
+  }
+  return Status::OK();
+}
+
 Status DEPipeline::GetNextAsMap(py::dict *output) {
   TensorMap row;
   Status s;
diff --git a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h
index 80d524982ab6d07c6c2e829dbc762ce16fd88a92..80a2c3c2ab4d398533208579411a94822984dae5 100644
--- a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h
+++ b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h
@@ -92,8 +92,14 @@ class DEPipeline {
   // Function to assign the node as root.
   Status AssignRootNode(const DsOpPtr &dataset_op);
 
+  // Function to get the column names in the last node in the tree in order
+  Status GetColumnNames(py::list *output);
+
+  // Function to prepare the tree for execution
+  Status PrepareTree(const int32_t num_epochs);
+
   // Function to launch the tree execution.
- Status LaunchTreeExec(int32_t num_epochs); + Status LaunchTreeExec(); // Get a row of data as dictionary of column name to the value. Status GetNextAsMap(py::dict *output); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc index a70f54bdeed3bd85e2f57829a52d9e208818f5ef..4bd3fd1c5515a4d2ca479fd12c46af76438ff174 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc @@ -83,7 +83,8 @@ void GeneratorOp::Dealloc() noexcept { PyGILState_STATE gstate; gstate = PyGILState_Ensure(); // GC the generator object within GIL - (void)generator_.dec_ref(); + if (generator_function_.ref_count() == 1) generator_function_.dec_ref(); + if (generator_.ref_count() == 1) (void)generator_.dec_ref(); // Release GIL PyGILState_Release(gstate); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h index 41d51733c1a5e2f27a066aa1db723e4fb6464d43..1272e1b7c6f06e0c75c1218e24b0a404014238ee 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h +++ b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h @@ -211,6 +211,13 @@ class ExecutionTree { // @return Bool - true is ExecutionTree is finished bool isFinished() const { return tree_state_ == TreeState::kDeTStateFinished; } + // Return if the ExecutionTree is ready. + // @return Bool - true is ExecutionTree is ready + bool isPrepared() const { + return tree_state_ == TreeState::kDeTStateReady || tree_state_ == kDeTStateExecuting || + tree_state_ == kDeTStateFinished; + } + // Set the ExecutionTree to Finished state. 
void SetFinished() { tree_state_ = TreeState::kDeTStateFinished; } diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 7c8e6bf9215d7d7598f6a5134dbf66b31f0f1719..be00552784637fc9ac222eb0546bb008534eefba 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -38,7 +38,7 @@ from mindspore._c_expression import typing from mindspore import log as logger from . import samplers -from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp +from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ check_rename, check_numpyslicesdataset, check_device_send, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ @@ -1203,6 +1203,12 @@ class Dataset: self._repeat_count = device_iter.get_repeat_count() device_iter.stop() + def get_col_names(self): + """ + Get names of the columns in the dataset + """ + return Iterator(self).get_col_names() + def output_shapes(self): """ Get the shapes of output data. 
diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index 2c0835bc2789aa65296fadaca2b2902029e94592..5981bb3daf524c7d4d717143986f99b44744ab23 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -93,7 +93,7 @@ class Iterator: root = self.__convert_node_postorder(self.dataset) self.depipeline.AssignRootNode(root) - self.depipeline.LaunchTreeExec(self.num_epochs) + self.depipeline.PrepareTree(self.num_epochs) self._index = 0 def stop(self): @@ -276,6 +276,9 @@ class Iterator: def num_classes(self): return self.depipeline.GetNumClasses() + def get_col_names(self): + return self.depipeline.GetColumnNames() + def __deepcopy__(self, memo): return self @@ -283,6 +286,10 @@ class SaveOp(Iterator): """ The derived class of Iterator with dict type. """ + def __init__(self, dataset, num_epochs=-1): + super().__init__(dataset, num_epochs) + self.depipeline.LaunchTreeExec() + def get_next(self): pass @@ -298,6 +305,10 @@ class DictIterator(Iterator): """ The derived class of Iterator with dict type. """ + def __init__(self, dataset, num_epochs=-1): + super().__init__(dataset, num_epochs) + self.depipeline.LaunchTreeExec() + def check_node_type(self, node): pass @@ -328,6 +339,7 @@ class TupleIterator(Iterator): columns = [columns] dataset = dataset.project(columns) super().__init__(dataset, num_epochs) + self.depipeline.LaunchTreeExec() def __iter__(self): return self diff --git a/tests/ut/python/dataset/test_get_col_names.py b/tests/ut/python/dataset/test_get_col_names.py new file mode 100644 index 0000000000000000000000000000000000000000..b2b4e210ffe1a94b87b7190f24a75ff6fb39b404 --- /dev/null +++ b/tests/ut/python/dataset/test_get_col_names.py @@ -0,0 +1,198 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import numpy as np + +import mindspore.dataset as ds +import mindspore.dataset.transforms.vision.c_transforms as vision + +CELEBA_DIR = "../data/dataset/testCelebAData" +CIFAR10_DIR = "../data/dataset/testCifar10Data" +CIFAR100_DIR = "../data/dataset/testCifar100Data" +CLUE_DIR = "../data/dataset/testCLUE/afqmc/train.json" +COCO_DIR = "../data/dataset/testCOCO/train" +COCO_ANNOTATION = "../data/dataset/testCOCO/annotations/train.json" +CSV_DIR = "../data/dataset/testCSV/1.csv" +IMAGE_FOLDER_DIR = "../data/dataset/testPK/data/" +MANIFEST_DIR = "../data/dataset/testManifestData/test.manifest" +MNIST_DIR = "../data/dataset/testMnistData" +TFRECORD_DIR = ["../data/dataset/testTFTestAllTypes/test.data"] +TFRECORD_SCHEMA = "../data/dataset/testTFTestAllTypes/datasetSchema.json" +VOC_DIR = "../data/dataset/testVOC2012" + + +def test_get_column_name_celeba(): + data = ds.CelebADataset(CELEBA_DIR) + assert data.get_col_names() == ["image", "attr"] + + +def test_get_column_name_cifar10(): + data = ds.Cifar10Dataset(CIFAR10_DIR) + assert data.get_col_names() == ["image", "label"] + + +def test_get_column_name_cifar100(): + data = ds.Cifar100Dataset(CIFAR100_DIR) + assert data.get_col_names() == ["image", "coarse_label", "fine_label"] + + +def test_get_column_name_clue(): + data = ds.CLUEDataset(CLUE_DIR, task="AFQMC", usage="train") + assert data.get_col_names() == ["label", "sentence1", "sentence2"] + + +def test_get_column_name_coco(): + data = ds.CocoDataset(COCO_DIR, 
annotation_file=COCO_ANNOTATION, task="Detection", + decode=True, shuffle=False) + assert data.get_col_names() == ["image", "bbox", "category_id", "iscrowd"] + + +def test_get_column_name_csv(): + data = ds.CSVDataset(CSV_DIR) + assert data.get_col_names() == ["1", "2", "3", "4"] + data = ds.CSVDataset(CSV_DIR, column_names=["col1", "col2", "col3", "col4"]) + assert data.get_col_names() == ["col1", "col2", "col3", "col4"] + + +def test_get_column_name_generator(): + def generator(): + for i in range(64): + yield (np.array([i]),) + + data = ds.GeneratorDataset(generator, ["data"]) + assert data.get_col_names() == ["data"] + + +def test_get_column_name_imagefolder(): + data = ds.ImageFolderDatasetV2(IMAGE_FOLDER_DIR) + assert data.get_col_names() == ["image", "label"] + + +def test_get_column_name_iterator(): + data = ds.Cifar10Dataset(CIFAR10_DIR) + itr = data.create_tuple_iterator(num_epochs=1) + assert itr.get_col_names() == ["image", "label"] + itr = data.create_dict_iterator(num_epochs=1) + assert itr.get_col_names() == ["image", "label"] + + +def test_get_column_name_manifest(): + data = ds.ManifestDataset(MANIFEST_DIR) + assert data.get_col_names() == ["image", "label"] + + +def test_get_column_name_map(): + data = ds.Cifar10Dataset(CIFAR10_DIR) + center_crop_op = vision.CenterCrop(10) + data = data.map(input_columns=["image"], operations=center_crop_op) + assert data.get_col_names() == ["image", "label"] + data = ds.Cifar10Dataset(CIFAR10_DIR) + data = data.map(input_columns=["image"], operations=center_crop_op, output_columns=["image"]) + assert data.get_col_names() == ["image", "label"] + data = ds.Cifar10Dataset(CIFAR10_DIR) + data = data.map(input_columns=["image"], operations=center_crop_op, output_columns=["col1"]) + assert data.get_col_names() == ["col1", "label"] + data = ds.Cifar10Dataset(CIFAR10_DIR) + data = data.map(input_columns=["image"], operations=center_crop_op, output_columns=["col1", "col2"], + columns_order=["col2", "col1"]) + assert 
data.get_col_names() == ["col2", "col1"] + + +def test_get_column_name_mnist(): + data = ds.MnistDataset(MNIST_DIR) + assert data.get_col_names() == ["image", "label"] + + +def test_get_column_name_numpy_slices(): + np_data = {"a": [1, 2], "b": [3, 4]} + data = ds.NumpySlicesDataset(np_data, shuffle=False) + assert data.get_col_names() == ["a", "b"] + data = ds.NumpySlicesDataset([1, 2, 3], shuffle=False) + assert data.get_col_names() == ["column_0"] + + +def test_get_column_name_tfrecord(): + data = ds.TFRecordDataset(TFRECORD_DIR, TFRECORD_SCHEMA) + assert data.get_col_names() == ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", "col_sint16", "col_sint32", + "col_sint64"] + data = ds.TFRecordDataset(TFRECORD_DIR, TFRECORD_SCHEMA, + columns_list=["col_sint16", "col_sint64", "col_2d", "col_binary"]) + assert data.get_col_names() == ["col_sint16", "col_sint64", "col_2d", "col_binary"] + + data = ds.TFRecordDataset(TFRECORD_DIR) + assert data.get_col_names() == ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", "col_sint16", "col_sint32", + "col_sint64", "col_sint8"] + s = ds.Schema() + s.add_column("line", "string", []) + s.add_column("words", "string", [-1]) + s.add_column("chinese", "string", []) + + data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=s) + assert data.get_col_names() == ["line", "words", "chinese"] + + +def test_get_column_name_to_device(): + data = ds.Cifar10Dataset(CIFAR10_DIR) + data = data.to_device() + assert data.get_col_names() == ["image", "label"] + + +def test_get_column_name_voc(): + data = ds.VOCDataset(VOC_DIR, task="Segmentation", mode="train", decode=True, shuffle=False) + assert data.get_col_names() == ["image", "target"] + + +def test_get_column_name_project(): + data = ds.Cifar10Dataset(CIFAR10_DIR) + assert data.get_col_names() == ["image", "label"] + data = data.project(columns=["image"]) + assert data.get_col_names() == ["image"] + + +def 
test_get_column_name_rename(): + data = ds.Cifar10Dataset(CIFAR10_DIR) + assert data.get_col_names() == ["image", "label"] + data = data.rename(["image", "label"], ["test1", "test2"]) + assert data.get_col_names() == ["test1", "test2"] + + +def test_get_column_name_zip(): + data1 = ds.Cifar10Dataset(CIFAR10_DIR) + assert data1.get_col_names() == ["image", "label"] + data2 = ds.CSVDataset(CSV_DIR) + assert data2.get_col_names() == ["1", "2", "3", "4"] + data = ds.zip((data1, data2)) + assert data.get_col_names() == ["image", "label", "1", "2", "3", "4"] + + +if __name__ == "__main__": + test_get_column_name_celeba() + test_get_column_name_cifar10() + test_get_column_name_cifar100() + test_get_column_name_clue() + test_get_column_name_coco() + test_get_column_name_csv() + test_get_column_name_generator() + test_get_column_name_imagefolder() + test_get_column_name_iterator() + test_get_column_name_manifest() + test_get_column_name_map() + test_get_column_name_mnist() + test_get_column_name_numpy_slices() + test_get_column_name_tfrecord() + test_get_column_name_to_device() + test_get_column_name_voc() + test_get_column_name_project() + test_get_column_name_rename() + test_get_column_name_zip()