提交 d76ac7c6 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5648 GetColumnNames for Python

Merge pull request !5648 from MahdiRahmaniHanzaki/get-col-name
......@@ -44,7 +44,14 @@ PYBIND_REGISTER(
[](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
.def("SetBatchParameters",
[](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
.def("LaunchTreeExec", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.LaunchTreeExec(num_epochs)); })
.def("PrepareTree", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.PrepareTree(num_epochs)); })
.def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); })
.def("GetColumnNames",
[](DEPipeline &de) {
py::list out;
THROW_IF_ERROR(de.GetColumnNames(&out));
return out;
})
.def("GetNextAsMap",
[](DEPipeline &de) {
py::dict out;
......
......@@ -172,9 +172,11 @@ Status DEPipeline::AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &
// Function to assign the node as root.
Status DEPipeline::AssignRootNode(const DsOpPtr &dataset_op) { return (tree_->AssignRoot(dataset_op)); }
// Function to prepare the tree
Status DEPipeline::PrepareTree(const int32_t num_epochs) { return tree_->Prepare(num_epochs); }
// Function to launch the tree execution.
Status DEPipeline::LaunchTreeExec(const int32_t num_epochs) {
RETURN_IF_NOT_OK(tree_->Prepare(num_epochs));
Status DEPipeline::LaunchTreeExec() {
RETURN_IF_NOT_OK(tree_->Launch());
iterator_ = std::make_unique<DatasetIterator>(tree_);
if (iterator_ == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create an Iterator.");
......@@ -189,6 +191,25 @@ void DEPipeline::PrintTree() {
}
}
Status DEPipeline::GetColumnNames(py::list *output) {
if (!tree_->isPrepared()) {
RETURN_STATUS_UNEXPECTED("GetColumnNames: Make sure to call prepare before calling GetColumnNames.");
}
std::unordered_map<std::string, int32_t> column_name_id_map = tree_->root()->column_name_id_map();
if (column_name_id_map.empty())
RETURN_STATUS_UNEXPECTED("GetColumnNames: Column names was empty. Make sure Prepare is called.");
std::vector<std::pair<std::string, int32_t>> column_name_id_vector(column_name_id_map.begin(),
column_name_id_map.end());
std::sort(column_name_id_vector.begin(), column_name_id_vector.end(),
[](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) {
return a.second < b.second;
});
for (auto item : column_name_id_vector) {
(*output).append(item.first);
}
return Status::OK();
}
Status DEPipeline::GetNextAsMap(py::dict *output) {
TensorMap row;
Status s;
......
......@@ -92,8 +92,14 @@ class DEPipeline {
// Function to assign the node as root.
Status AssignRootNode(const DsOpPtr &dataset_op);
// Function to get the column names in the last node in the tree in order
Status GetColumnNames(py::list *output);
// Function to prepare the tree for execution
Status PrepareTree(const int32_t num_epochs);
// Function to launch the tree execution.
Status LaunchTreeExec(int32_t num_epochs);
Status LaunchTreeExec();
// Get a row of data as dictionary of column name to the value.
Status GetNextAsMap(py::dict *output);
......
......@@ -83,7 +83,8 @@ void GeneratorOp::Dealloc() noexcept {
PyGILState_STATE gstate;
gstate = PyGILState_Ensure();
// GC the generator object within GIL
(void)generator_.dec_ref();
if (generator_function_.ref_count() == 1) generator_function_.dec_ref();
if (generator_.ref_count() == 1) (void)generator_.dec_ref();
// Release GIL
PyGILState_Release(gstate);
}
......
......@@ -211,6 +211,13 @@ class ExecutionTree {
// @return Bool - true is ExecutionTree is finished
bool isFinished() const { return tree_state_ == TreeState::kDeTStateFinished; }
// Return if the ExecutionTree is ready.
// @return Bool - true is ExecutionTree is ready
bool isPrepared() const {
return tree_state_ == TreeState::kDeTStateReady || tree_state_ == kDeTStateExecuting ||
tree_state_ == kDeTStateFinished;
}
// Set the ExecutionTree to Finished state.
void SetFinished() { tree_state_ = TreeState::kDeTStateFinished; }
......
......@@ -38,7 +38,7 @@ from mindspore._c_expression import typing
from mindspore import log as logger
from . import samplers
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
check_rename, check_numpyslicesdataset, check_device_send, \
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
......@@ -1203,6 +1203,12 @@ class Dataset:
self._repeat_count = device_iter.get_repeat_count()
device_iter.stop()
def get_col_names(self):
"""
Get names of the columns in the dataset
"""
return Iterator(self).get_col_names()
def output_shapes(self):
"""
Get the shapes of output data.
......
......@@ -93,7 +93,7 @@ class Iterator:
root = self.__convert_node_postorder(self.dataset)
self.depipeline.AssignRootNode(root)
self.depipeline.LaunchTreeExec(self.num_epochs)
self.depipeline.PrepareTree(self.num_epochs)
self._index = 0
def stop(self):
......@@ -276,6 +276,9 @@ class Iterator:
def num_classes(self):
return self.depipeline.GetNumClasses()
def get_col_names(self):
return self.depipeline.GetColumnNames()
def __deepcopy__(self, memo):
return self
......@@ -283,6 +286,10 @@ class SaveOp(Iterator):
"""
The derived class of Iterator with dict type.
"""
def __init__(self, dataset, num_epochs=-1):
super().__init__(dataset, num_epochs)
self.depipeline.LaunchTreeExec()
def get_next(self):
pass
......@@ -298,6 +305,10 @@ class DictIterator(Iterator):
"""
The derived class of Iterator with dict type.
"""
def __init__(self, dataset, num_epochs=-1):
super().__init__(dataset, num_epochs)
self.depipeline.LaunchTreeExec()
def check_node_type(self, node):
pass
......@@ -328,6 +339,7 @@ class TupleIterator(Iterator):
columns = [columns]
dataset = dataset.project(columns)
super().__init__(dataset, num_epochs)
self.depipeline.LaunchTreeExec()
def __iter__(self):
return self
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
CELEBA_DIR = "../data/dataset/testCelebAData"
CIFAR10_DIR = "../data/dataset/testCifar10Data"
CIFAR100_DIR = "../data/dataset/testCifar100Data"
CLUE_DIR = "../data/dataset/testCLUE/afqmc/train.json"
COCO_DIR = "../data/dataset/testCOCO/train"
COCO_ANNOTATION = "../data/dataset/testCOCO/annotations/train.json"
CSV_DIR = "../data/dataset/testCSV/1.csv"
IMAGE_FOLDER_DIR = "../data/dataset/testPK/data/"
MANIFEST_DIR = "../data/dataset/testManifestData/test.manifest"
MNIST_DIR = "../data/dataset/testMnistData"
TFRECORD_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
TFRECORD_SCHEMA = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
VOC_DIR = "../data/dataset/testVOC2012"
def test_get_column_name_celeba():
data = ds.CelebADataset(CELEBA_DIR)
assert data.get_col_names() == ["image", "attr"]
def test_get_column_name_cifar10():
data = ds.Cifar10Dataset(CIFAR10_DIR)
assert data.get_col_names() == ["image", "label"]
def test_get_column_name_cifar100():
data = ds.Cifar100Dataset(CIFAR100_DIR)
assert data.get_col_names() == ["image", "coarse_label", "fine_label"]
def test_get_column_name_clue():
data = ds.CLUEDataset(CLUE_DIR, task="AFQMC", usage="train")
assert data.get_col_names() == ["label", "sentence1", "sentence2"]
def test_get_column_name_coco():
data = ds.CocoDataset(COCO_DIR, annotation_file=COCO_ANNOTATION, task="Detection",
decode=True, shuffle=False)
assert data.get_col_names() == ["image", "bbox", "category_id", "iscrowd"]
def test_get_column_name_csv():
data = ds.CSVDataset(CSV_DIR)
assert data.get_col_names() == ["1", "2", "3", "4"]
data = ds.CSVDataset(CSV_DIR, column_names=["col1", "col2", "col3", "col4"])
assert data.get_col_names() == ["col1", "col2", "col3", "col4"]
def test_get_column_name_generator():
def generator():
for i in range(64):
yield (np.array([i]),)
data = ds.GeneratorDataset(generator, ["data"])
assert data.get_col_names() == ["data"]
def test_get_column_name_imagefolder():
data = ds.ImageFolderDatasetV2(IMAGE_FOLDER_DIR)
assert data.get_col_names() == ["image", "label"]
def test_get_column_name_iterator():
data = ds.Cifar10Dataset(CIFAR10_DIR)
itr = data.create_tuple_iterator(num_epochs=1)
assert itr.get_col_names() == ["image", "label"]
itr = data.create_dict_iterator(num_epochs=1)
assert itr.get_col_names() == ["image", "label"]
def test_get_column_name_manifest():
data = ds.ManifestDataset(MANIFEST_DIR)
assert data.get_col_names() == ["image", "label"]
def test_get_column_name_map():
data = ds.Cifar10Dataset(CIFAR10_DIR)
center_crop_op = vision.CenterCrop(10)
data = data.map(input_columns=["image"], operations=center_crop_op)
assert data.get_col_names() == ["image", "label"]
data = ds.Cifar10Dataset(CIFAR10_DIR)
data = data.map(input_columns=["image"], operations=center_crop_op, output_columns=["image"])
assert data.get_col_names() == ["image", "label"]
data = ds.Cifar10Dataset(CIFAR10_DIR)
data = data.map(input_columns=["image"], operations=center_crop_op, output_columns=["col1"])
assert data.get_col_names() == ["col1", "label"]
data = ds.Cifar10Dataset(CIFAR10_DIR)
data = data.map(input_columns=["image"], operations=center_crop_op, output_columns=["col1", "col2"],
columns_order=["col2", "col1"])
assert data.get_col_names() == ["col2", "col1"]
def test_get_column_name_mnist():
data = ds.MnistDataset(MNIST_DIR)
assert data.get_col_names() == ["image", "label"]
def test_get_column_name_numpy_slices():
np_data = {"a": [1, 2], "b": [3, 4]}
data = ds.NumpySlicesDataset(np_data, shuffle=False)
assert data.get_col_names() == ["a", "b"]
data = ds.NumpySlicesDataset([1, 2, 3], shuffle=False)
assert data.get_col_names() == ["column_0"]
def test_get_column_name_tfrecord():
data = ds.TFRecordDataset(TFRECORD_DIR, TFRECORD_SCHEMA)
assert data.get_col_names() == ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", "col_sint16", "col_sint32",
"col_sint64"]
data = ds.TFRecordDataset(TFRECORD_DIR, TFRECORD_SCHEMA,
columns_list=["col_sint16", "col_sint64", "col_2d", "col_binary"])
assert data.get_col_names() == ["col_sint16", "col_sint64", "col_2d", "col_binary"]
data = ds.TFRecordDataset(TFRECORD_DIR)
assert data.get_col_names() == ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", "col_sint16", "col_sint32",
"col_sint64", "col_sint8"]
s = ds.Schema()
s.add_column("line", "string", [])
s.add_column("words", "string", [-1])
s.add_column("chinese", "string", [])
data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=s)
assert data.get_col_names() == ["line", "words", "chinese"]
def test_get_column_name_to_device():
data = ds.Cifar10Dataset(CIFAR10_DIR)
data = data.to_device()
assert data.get_col_names() == ["image", "label"]
def test_get_column_name_voc():
data = ds.VOCDataset(VOC_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
assert data.get_col_names() == ["image", "target"]
def test_get_column_name_project():
data = ds.Cifar10Dataset(CIFAR10_DIR)
assert data.get_col_names() == ["image", "label"]
data = data.project(columns=["image"])
assert data.get_col_names() == ["image"]
def test_get_column_name_rename():
data = ds.Cifar10Dataset(CIFAR10_DIR)
assert data.get_col_names() == ["image", "label"]
data = data.rename(["image", "label"], ["test1", "test2"])
assert data.get_col_names() == ["test1", "test2"]
def test_get_column_name_zip():
data1 = ds.Cifar10Dataset(CIFAR10_DIR)
assert data1.get_col_names() == ["image", "label"]
data2 = ds.CSVDataset(CSV_DIR)
assert data2.get_col_names() == ["1", "2", "3", "4"]
data = ds.zip((data1, data2))
assert data.get_col_names() == ["image", "label", "1", "2", "3", "4"]
if __name__ == "__main__":
test_get_column_name_celeba()
test_get_column_name_cifar10()
test_get_column_name_cifar100()
test_get_column_name_clue()
test_get_column_name_coco()
test_get_column_name_csv()
test_get_column_name_generator()
test_get_column_name_imagefolder()
test_get_column_name_iterator()
test_get_column_name_manifest()
test_get_column_name_map()
test_get_column_name_mnist()
test_get_column_name_numpy_slices()
test_get_column_name_tfrecord()
test_get_column_name_to_device()
test_get_column_name_voc()
test_get_column_name_project()
test_get_column_name_rename()
test_get_column_name_zip()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册