/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/api/python/pybind_register.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/voc_op.h"

namespace mindspore {
namespace dataset {

// Python bindings for the static "count rows" / "class indexing" helpers of the source dataset ops.
//
// Every registration below follows the same shape: a Python class named after the op is created on the
// module handed to the callback, and one or more static methods are attached to it. Each static method
// forwards its arguments to the matching static C++ function, wraps the call in THROW_IF_ERROR, and
// returns the value the C++ function wrote through its output pointer(s).
//
// NOTE(review): the second PYBIND_REGISTER argument (1) is presumably a registration order/priority
// relative to the DatasetOp base binding - confirm in pybind_register.h.
// NOTE(review): the template arguments spelled out below (DatasetOp base with a std::shared_ptr holder,
// std::string file names, int32_t class ids, mindrecord::ShardOperator) should be confirmed against the
// op headers.

// CifarOp.get_num_rows(dir, usage, isCifar10) -> row count reported by CifarOp::CountTotalRows.
PYBIND_REGISTER(CifarOp, 1, ([](const py::module *m) {
                  (void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
                    .def_static("get_num_rows", [](const std::string &dir, const std::string &usage, bool isCifar10) {
                      int64_t count = 0;
                      THROW_IF_ERROR(CifarOp::CountTotalRows(dir, usage, isCifar10, &count));
                      return count;
                    });
                }));

// ClueOp.get_num_rows(files) -> row count reported by ClueOp::CountAllFileRows.
// A None entry in `files` is passed on as an empty file name.
PYBIND_REGISTER(ClueOp, 1, ([](const py::module *m) {
                  (void)py::class_<ClueOp, DatasetOp, std::shared_ptr<ClueOp>>(*m, "ClueOp")
                    .def_static("get_num_rows", [](const py::list &files) {
                      int64_t count = 0;
                      std::vector<std::string> filenames;
                      for (auto file : files) {
                        if (file.is_none()) {
                          (void)filenames.emplace_back("");
                        } else {
                          filenames.push_back(py::str(file));
                        }
                      }
                      THROW_IF_ERROR(ClueOp::CountAllFileRows(filenames, &count));
                      return count;
                    });
                }));

// CsvOp.get_num_rows(files, csv_header) -> row count reported by CsvOp::CountAllFileRows.
// A None entry in `files` is passed on as an empty file name.
PYBIND_REGISTER(CsvOp, 1, ([](const py::module *m) {
                  (void)py::class_<CsvOp, DatasetOp, std::shared_ptr<CsvOp>>(*m, "CsvOp")
                    .def_static("get_num_rows", [](const py::list &files, bool csv_header) {
                      int64_t count = 0;
                      std::vector<std::string> filenames;
                      for (auto file : files) {
                        if (file.is_none()) {
                          (void)filenames.emplace_back("");
                        } else {
                          filenames.push_back(py::str(file));
                        }
                      }
                      THROW_IF_ERROR(CsvOp::CountAllFileRows(filenames, csv_header, &count));
                      return count;
                    });
                }));

// CocoOp.get_class_indexing(dir, file, task) -> list filled in by CocoOp::GetClassIndexing.
// CocoOp.get_num_rows(dir, file, task)       -> row count reported by CocoOp::CountTotalRows.
PYBIND_REGISTER(CocoOp, 1, ([](const py::module *m) {
                  (void)py::class_<CocoOp, DatasetOp, std::shared_ptr<CocoOp>>(*m, "CocoOp")
                    .def_static("get_class_indexing",
                                [](const std::string &dir, const std::string &file, const std::string &task) {
                                  std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
                                  THROW_IF_ERROR(CocoOp::GetClassIndexing(dir, file, task, &output_class_indexing));
                                  return output_class_indexing;
                                })
                    .def_static("get_num_rows",
                                [](const std::string &dir, const std::string &file, const std::string &task) {
                                  int64_t count = 0;
                                  THROW_IF_ERROR(CocoOp::CountTotalRows(dir, file, task, &count));
                                  return count;
                                });
                }));

// ImageFolderOp.get_num_rows_and_classes(path) -> (count, num_classes) tuple from
// ImageFolderOp::CountRowsAndClasses, which is always called with an empty std::set argument here.
PYBIND_REGISTER(ImageFolderOp, 1, ([](const py::module *m) {
                  (void)py::class_<ImageFolderOp, DatasetOp, std::shared_ptr<ImageFolderOp>>(*m, "ImageFolderOp")
                    .def_static("get_num_rows_and_classes", [](const std::string &path) {
                      int64_t count = 0, num_classes = 0;
                      THROW_IF_ERROR(
                        ImageFolderOp::CountRowsAndClasses(path, std::set<std::string>{}, &count, &num_classes));
                      return py::make_tuple(count, num_classes);
                    });
                }));

// ManifestOp.get_num_rows_and_classes(file, dict, usage) -> (count, num_classes) tuple from
//                                                           ManifestOp::CountTotalRows.
// ManifestOp.get_class_indexing(file, dict, usage)       -> map filled in by ManifestOp::GetClassIndexing.
PYBIND_REGISTER(ManifestOp, 1, ([](const py::module *m) {
                  (void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
                    .def_static("get_num_rows_and_classes",
                                [](const std::string &file, const py::dict &dict, const std::string &usage) {
                                  int64_t count = 0, num_classes = 0;
                                  THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes));
                                  return py::make_tuple(count, num_classes);
                                })
                    .def_static("get_class_indexing",
                                [](const std::string &file, const py::dict &dict, const std::string &usage) {
                                  std::map<std::string, int32_t> output_class_indexing;
                                  THROW_IF_ERROR(
                                    ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing));
                                  return output_class_indexing;
                                });
                }));

// MindRecordOp.get_num_rows(paths, load_dataset, sampler, num_padded) -> row count reported by
// MindRecordOp::CountTotalRows.
PYBIND_REGISTER(MindRecordOp, 1, ([](const py::module *m) {
                  (void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
                    .def_static("get_num_rows", [](const std::vector<std::string> &paths, bool load_dataset,
                                                   const py::object &sampler, const int64_t num_padded) {
                      int64_t count = 0;
                      // Stays null unless the Python sampler object offers create_for_minddataset(); in that
                      // case the method is called and its result is cast to the shard operator pointer.
                      std::shared_ptr<mindrecord::ShardOperator> op;
                      if (py::hasattr(sampler, "create_for_minddataset")) {
                        auto create = sampler.attr("create_for_minddataset");
                        op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
                      }
                      THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded));
                      return count;
                    });
                }));

// MnistOp.get_num_rows(dir, usage) -> row count reported by MnistOp::CountTotalRows.
PYBIND_REGISTER(MnistOp, 1, ([](const py::module *m) {
                  (void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
                    .def_static("get_num_rows", [](const std::string &dir, const std::string &usage) {
                      int64_t count = 0;
                      THROW_IF_ERROR(MnistOp::CountTotalRows(dir, usage, &count));
                      return count;
                    });
                }));

// TextFileOp.get_num_rows(files) -> row count reported by TextFileOp::CountAllFileRows.
// A None entry in `files` is passed on as an empty file name.
PYBIND_REGISTER(TextFileOp, 1, ([](const py::module *m) {
                  (void)py::class_<TextFileOp, DatasetOp, std::shared_ptr<TextFileOp>>(*m, "TextFileOp")
                    .def_static("get_num_rows", [](const py::list &files) {
                      int64_t count = 0;
                      std::vector<std::string> filenames;
                      for (auto file : files) {
                        if (file.is_none()) {
                          (void)filenames.emplace_back("");
                        } else {
                          filenames.push_back(py::str(file));
                        }
                      }
                      THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count));
                      return count;
                    });
                }));

// TFReaderOp.get_num_rows(files, numParallelWorkers, estimate) -> row count reported by
// TFReaderOp::CountTotalRows. A None entry in `files` is passed on as an empty file name.
// NOTE: the C++ default on `estimate` is not exported to Python on its own; pybind11 only exposes
// defaults declared through py::arg, so Python callers still have to pass all three arguments.
PYBIND_REGISTER(TFReaderOp, 1, ([](const py::module *m) {
                  (void)py::class_<TFReaderOp, DatasetOp, std::shared_ptr<TFReaderOp>>(*m, "TFReaderOp")
                    .def_static(
                      "get_num_rows", [](const py::list &files, int64_t numParallelWorkers, bool estimate = false) {
                        int64_t count = 0;
                        std::vector<std::string> filenames;
                        for (auto file : files) {
                          if (file.is_none()) {
                            (void)filenames.emplace_back("");
                          } else {
                            filenames.push_back(py::str(file));
                          }
                        }
                        THROW_IF_ERROR(TFReaderOp::CountTotalRows(&count, filenames, numParallelWorkers, estimate));
                        return count;
                      });
                }));

// VOCOp.get_num_rows(dir, task_type, task_mode, dict, numSamples) -> row count reported by
//                                                                    VOCOp::CountTotalRows.
// VOCOp.get_class_indexing(dir, task_type, task_mode, dict)       -> map filled in by VOCOp::GetClassIndexing.
PYBIND_REGISTER(VOCOp, 1, ([](const py::module *m) {
                  (void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
                    .def_static("get_num_rows",
                                // numSamples is part of the Python-facing signature but is not forwarded to
                                // CountTotalRows.
                                [](const std::string &dir, const std::string &task_type, const std::string &task_mode,
                                   const py::dict &dict, int64_t numSamples) {
                                  int64_t count = 0;
                                  THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count));
                                  return count;
                                })
                    .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
                                                         const std::string &task_mode, const py::dict &dict) {
                      std::map<std::string, int32_t> output_class_indexing;
                      THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing));
                      return output_class_indexing;
                    });
                }));

}  // namespace dataset
}  // namespace mindspore