Commit 82103a69 authored by nhussain

restructuring

Parent 7ec0b585
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
if (ENABLE_PYTHON)
-add_library(APItoPython OBJECT
-de_pipeline.cc
-python_bindings.cc
-)
-target_include_directories(APItoPython PRIVATE ${pybind11_INCLUDE_DIRS})
-endif()
+add_library(APItoPython OBJECT
+python/de_pipeline.cc
+python/pybind_register.cc
+python/bindings.cc
+python/bindings/dataset/engine/cache/bindings.cc
+python/bindings/dataset/core/bindings.cc
+python/bindings/dataset/kernels/data/bindings.cc
+python/bindings/dataset/kernels/bindings.cc
+python/bindings/dataset/engine/datasetops/bindings.cc
+python/bindings/dataset/engine/datasetops/source/bindings.cc
+python/bindings/dataset/engine/gnn/bindings.cc
+python/bindings/dataset/kernels/image/bindings.cc
+python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc
+python/bindings/dataset/text/bindings.cc
+python/bindings/dataset/text/kernels/bindings.cc
+python/bindings/mindrecord/include/bindings.cc
+)
+target_include_directories(APItoPython PRIVATE ${pybind11_INCLUDE_DIRS})
+endif ()
add_library(cpp-API OBJECT
datasets.cc
iterator.cc
transforms.cc
samplers.cc
)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/api/python/de_pipeline.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(
DEPipeline, 0, ([](const py::module *m) {
(void)py::class_<DEPipeline>(*m, "DEPipeline")
.def(py::init<>())
.def(
"AddNodeToTree",
[](DEPipeline &de, const OpName &op_name, const py::dict &args) {
py::dict out;
THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out));
return out;
},
py::return_value_policy::reference)
.def_static("AddChildToParentNode",
[](const DsOpPtr &child_op, const DsOpPtr &parent_op) {
THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op));
})
.def("AssignRootNode",
[](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
.def("SetBatchParameters",
[](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
.def("LaunchTreeExec", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.LaunchTreeExec(num_epochs)); })
.def("GetNextAsMap",
[](DEPipeline &de) {
py::dict out;
THROW_IF_ERROR(de.GetNextAsMap(&out));
return out;
})
.def("GetNextAsList",
[](DEPipeline &de) {
py::list out;
THROW_IF_ERROR(de.GetNextAsList(&out));
return out;
})
.def("GetOutputShapes",
[](DEPipeline &de) {
py::list out;
THROW_IF_ERROR(de.GetOutputShapes(&out));
return out;
})
.def("GetOutputTypes",
[](DEPipeline &de) {
py::list out;
THROW_IF_ERROR(de.GetOutputTypes(&out));
return out;
})
.def("GetDatasetSize", &DEPipeline::GetDatasetSize)
.def("GetBatchSize", &DEPipeline::GetBatchSize)
.def("GetNumClasses", &DEPipeline::GetNumClasses)
.def("GetRepeatCount", &DEPipeline::GetRepeatCount)
.def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
.def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
return true;
});
}));
PYBIND_REGISTER(OpName, 0, ([](const py::module *m) {
(void)py::enum_<OpName>(*m, "OpName", py::arithmetic())
.value("SHUFFLE", OpName::kShuffle)
.value("BATCH", OpName::kBatch)
.value("BUCKETBATCH", OpName::kBucketBatch)
.value("BARRIER", OpName::kBarrier)
.value("MINDRECORD", OpName::kMindrecord)
.value("CACHE", OpName::kCache)
.value("REPEAT", OpName::kRepeat)
.value("SKIP", OpName::kSkip)
.value("TAKE", OpName::kTake)
.value("ZIP", OpName::kZip)
.value("CONCAT", OpName::kConcat)
.value("MAP", OpName::kMap)
.value("FILTER", OpName::kFilter)
.value("DEVICEQUEUE", OpName::kDeviceQueue)
.value("GENERATOR", OpName::kGenerator)
.export_values()
.value("RENAME", OpName::kRename)
.value("TFREADER", OpName::kTfReader)
.value("PROJECT", OpName::kProject)
.value("IMAGEFOLDER", OpName::kImageFolder)
.value("MNIST", OpName::kMnist)
.value("MANIFEST", OpName::kManifest)
.value("VOC", OpName::kVoc)
.value("COCO", OpName::kCoco)
.value("CIFAR10", OpName::kCifar10)
.value("CIFAR100", OpName::kCifar100)
.value("RANDOMDATA", OpName::kRandomData)
.value("BUILDVOCAB", OpName::kBuildVocab)
.value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab)
.value("CELEBA", OpName::kCelebA)
.value("TEXTFILE", OpName::kTextFile)
.value("EPOCHCTRL", OpName::kEpochCtrl)
.value("CSV", OpName::kCsv)
.value("CLUE", OpName::kClue);
}));
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/api/python/de_pipeline.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(GlobalContext, 0, ([](const py::module *m) {
(void)py::class_<GlobalContext>(*m, "GlobalContext")
.def_static("config_manager", &GlobalContext::config_manager, py::return_value_policy::reference);
}));
PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) {
(void)py::class_<ConfigManager, std::shared_ptr<ConfigManager>>(*m, "ConfigManager")
.def("__str__", &ConfigManager::ToString)
.def("set_rows_per_buffer", &ConfigManager::set_rows_per_buffer)
.def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers)
.def("set_worker_connector_size", &ConfigManager::set_worker_connector_size)
.def("set_op_connector_size", &ConfigManager::set_op_connector_size)
.def("set_seed", &ConfigManager::set_seed)
.def("set_monitor_sampling_interval", &ConfigManager::set_monitor_sampling_interval)
.def("get_rows_per_buffer", &ConfigManager::rows_per_buffer)
.def("get_num_parallel_workers", &ConfigManager::num_parallel_workers)
.def("get_worker_connector_size", &ConfigManager::worker_connector_size)
.def("get_op_connector_size", &ConfigManager::op_connector_size)
.def("get_seed", &ConfigManager::seed)
.def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval)
.def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); });
}));
PYBIND_REGISTER(Tensor, 0, ([](const py::module *m) {
(void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor", py::buffer_protocol())
.def(py::init([](py::array arr) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(Tensor::CreateFromNpArray(arr, &out));
return out;
}))
.def_buffer([](Tensor &tensor) {
py::buffer_info info;
THROW_IF_ERROR(Tensor::GetBufferInfo(&tensor, &info));
return info;
})
.def("__str__", &Tensor::ToString)
.def("shape", &Tensor::shape)
.def("type", &Tensor::type)
.def("as_array", [](py::object &t) {
auto &tensor = py::cast<Tensor &>(t);
if (tensor.type() == DataType::DE_STRING) {
py::array res;
tensor.GetDataAsNumpyStrings(&res);
return res;
}
py::buffer_info info;
THROW_IF_ERROR(Tensor::GetBufferInfo(&tensor, &info));
return py::array(pybind11::dtype(info), info.shape, info.strides, info.ptr, t);
});
}));
PYBIND_REGISTER(TensorShape, 0, ([](const py::module *m) {
(void)py::class_<TensorShape>(*m, "TensorShape")
.def(py::init<py::list>())
.def("__str__", &TensorShape::ToString)
.def("as_list", &TensorShape::AsPyList)
.def("is_known", &TensorShape::known);
}));
PYBIND_REGISTER(DataType, 0, ([](const py::module *m) {
(void)py::class_<DataType>(*m, "DataType")
.def(py::init<std::string>())
.def(py::self == py::self)
.def("__str__", &DataType::ToString)
.def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
}));
PYBIND_REGISTER(BorderType, 0, ([](const py::module *m) {
(void)py::enum_<BorderType>(*m, "BorderType", py::arithmetic())
.value("DE_BORDER_CONSTANT", BorderType::kConstant)
.value("DE_BORDER_EDGE", BorderType::kEdge)
.value("DE_BORDER_REFLECT", BorderType::kReflect)
.value("DE_BORDER_SYMMETRIC", BorderType::kSymmetric)
.export_values();
}));
PYBIND_REGISTER(InterpolationMode, 0, ([](const py::module *m) {
(void)py::enum_<InterpolationMode>(*m, "InterpolationMode", py::arithmetic())
.value("DE_INTER_LINEAR", InterpolationMode::kLinear)
.value("DE_INTER_CUBIC", InterpolationMode::kCubic)
.value("DE_INTER_AREA", InterpolationMode::kArea)
.value("DE_INTER_NEAREST_NEIGHBOUR", InterpolationMode::kNearestNeighbour)
.export_values();
}));
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/engine/cache/cache_client.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) {
(void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
.def(py::init<uint32_t, uint64_t, bool>());
}));
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/engine/datasetops/batch_op.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(CBatchInfo, 0, ([](const py::module *m) {
(void)py::class_<BatchOp::CBatchInfo>(*m, "CBatchInfo")
.def(py::init<int64_t, int64_t, int64_t>())
.def("get_epoch_num", &BatchOp::CBatchInfo::get_epoch_num)
.def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num);
}));
PYBIND_REGISTER(DatasetOp, 0, ([](const py::module *m) {
(void)py::class_<DatasetOp, std::shared_ptr<DatasetOp>>(*m, "DatasetOp");
}));
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/api/python/pybind_register.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(CifarOp, 1, ([](const py::module *m) {
(void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
.def_static("get_num_rows", [](const std::string &dir, bool isCifar10) {
int64_t count = 0;
THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count));
return count;
});
}));
PYBIND_REGISTER(ClueOp, 1, ([](const py::module *m) {
(void)py::class_<ClueOp, DatasetOp, std::shared_ptr<ClueOp>>(*m, "ClueOp")
.def_static("get_num_rows", [](const py::list &files) {
int64_t count = 0;
std::vector<std::string> filenames;
for (auto file : files) {
file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file));
}
THROW_IF_ERROR(ClueOp::CountAllFileRows(filenames, &count));
return count;
});
}));
PYBIND_REGISTER(CsvOp, 1, ([](const py::module *m) {
(void)py::class_<CsvOp, DatasetOp, std::shared_ptr<CsvOp>>(*m, "CsvOp")
.def_static("get_num_rows", [](const py::list &files, bool csv_header) {
int64_t count = 0;
std::vector<std::string> filenames;
for (auto file : files) {
file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file));
}
THROW_IF_ERROR(CsvOp::CountAllFileRows(filenames, csv_header, &count));
return count;
});
}));
PYBIND_REGISTER(CocoOp, 1, ([](const py::module *m) {
(void)py::class_<CocoOp, DatasetOp, std::shared_ptr<CocoOp>>(*m, "CocoOp")
.def_static("get_class_indexing",
[](const std::string &dir, const std::string &file, const std::string &task) {
std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
THROW_IF_ERROR(CocoOp::GetClassIndexing(dir, file, task, &output_class_indexing));
return output_class_indexing;
})
.def_static("get_num_rows",
[](const std::string &dir, const std::string &file, const std::string &task) {
int64_t count = 0;
THROW_IF_ERROR(CocoOp::CountTotalRows(dir, file, task, &count));
return count;
});
}));
PYBIND_REGISTER(ImageFolderOp, 1, ([](const py::module *m) {
(void)py::class_<ImageFolderOp, DatasetOp, std::shared_ptr<ImageFolderOp>>(*m, "ImageFolderOp")
.def_static("get_num_rows_and_classes", [](const std::string &path) {
int64_t count = 0, num_classes = 0;
THROW_IF_ERROR(
ImageFolderOp::CountRowsAndClasses(path, std::set<std::string>{}, &count, &num_classes));
return py::make_tuple(count, num_classes);
});
}));
PYBIND_REGISTER(ManifestOp, 1, ([](const py::module *m) {
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
.def_static("get_num_rows_and_classes",
[](const std::string &file, const py::dict &dict, const std::string &usage) {
int64_t count = 0, num_classes = 0;
THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes));
return py::make_tuple(count, num_classes);
})
.def_static("get_class_indexing", [](const std::string &file, const py::dict &dict,
const std::string &usage) {
std::map<std::string, int32_t> output_class_indexing;
THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing));
return output_class_indexing;
});
}));
PYBIND_REGISTER(MindRecordOp, 1, ([](const py::module *m) {
(void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
.def_static("get_num_rows", [](const std::vector<std::string> &paths, bool load_dataset,
const py::object &sampler, const int64_t num_padded) {
int64_t count = 0;
std::shared_ptr<mindrecord::ShardOperator> op;
if (py::hasattr(sampler, "create_for_minddataset")) {
auto create = sampler.attr("create_for_minddataset");
op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
}
THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded));
return count;
});
}));
PYBIND_REGISTER(MnistOp, 1, ([](const py::module *m) {
(void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
.def_static("get_num_rows", [](const std::string &dir) {
int64_t count = 0;
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count));
return count;
});
}));
PYBIND_REGISTER(TextFileOp, 1, ([](const py::module *m) {
(void)py::class_<TextFileOp, DatasetOp, std::shared_ptr<TextFileOp>>(*m, "TextFileOp")
.def_static("get_num_rows", [](const py::list &files) {
int64_t count = 0;
std::vector<std::string> filenames;
for (auto file : files) {
!file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back("");
}
THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count));
return count;
});
}));
PYBIND_REGISTER(TFReaderOp, 1, ([](const py::module *m) {
(void)py::class_<TFReaderOp, DatasetOp, std::shared_ptr<TFReaderOp>>(*m, "TFReaderOp")
.def_static(
"get_num_rows", [](const py::list &files, int64_t numParallelWorkers, bool estimate = false) {
int64_t count = 0;
std::vector<std::string> filenames;
for (auto l : files) {
!l.is_none() ? filenames.push_back(py::str(l)) : (void)filenames.emplace_back("");
}
THROW_IF_ERROR(TFReaderOp::CountTotalRows(&count, filenames, numParallelWorkers, estimate));
return count;
});
}));
PYBIND_REGISTER(VOCOp, 1, ([](const py::module *m) {
(void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
.def_static("get_num_rows",
[](const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t numSamples) {
int64_t count = 0;
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count));
return count;
})
.def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
const std::string &task_mode, const py::dict &dict) {
std::map<std::string, int32_t> output_class_indexing;
THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing));
return output_class_indexing;
});
}));
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(Sampler, 0, ([](const py::module *m) {
(void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler")
.def("set_num_rows",
[](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
.def("set_num_samples",
[](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
.def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
.def("get_indices",
[](Sampler &self) {
py::array ret;
THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
return ret;
})
.def("add_child", [](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) {
THROW_IF_ERROR(self->AddChild(child));
});
}));
PYBIND_REGISTER(DistributedSampler, 1, ([](const py::module *m) {
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(
*m, "DistributedSampler")
.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>());
}));
PYBIND_REGISTER(PKSampler, 1, ([](const py::module *m) {
(void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
.def(py::init<int64_t, int64_t, bool>());
}));
PYBIND_REGISTER(PythonSampler, 1, ([](const py::module *m) {
(void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
.def(py::init<int64_t, py::object>());
}));
PYBIND_REGISTER(RandomSampler, 1, ([](const py::module *m) {
(void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
.def(py::init<int64_t, bool, bool>());
}));
PYBIND_REGISTER(SequentialSampler, 1, ([](const py::module *m) {
(void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m,
"SequentialSampler")
.def(py::init<int64_t, int64_t>());
}));
PYBIND_REGISTER(SubsetRandomSampler, 1, ([](const py::module *m) {
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(
*m, "SubsetRandomSampler")
.def(py::init<int64_t, std::vector<int64_t>>());
}));
PYBIND_REGISTER(WeightedRandomSampler, 1, ([](const py::module *m) {
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(
*m, "WeightedRandomSampler")
.def(py::init<int64_t, std::vector<double>, bool>());
}));
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/engine/gnn/graph.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(
Graph, 0, ([](const py::module *m) {
(void)py::class_<gnn::Graph, std::shared_ptr<gnn::Graph>>(*m, "Graph")
.def(py::init([](std::string dataset_file, int32_t num_workers) {
std::shared_ptr<gnn::Graph> g_out = std::make_shared<gnn::Graph>(dataset_file, num_workers);
THROW_IF_ERROR(g_out->Init());
return g_out;
}))
.def("get_all_nodes",
[](gnn::Graph &g, gnn::NodeType node_type) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetAllNodes(node_type, &out));
return out;
})
.def("get_all_edges",
[](gnn::Graph &g, gnn::EdgeType edge_type) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetAllEdges(edge_type, &out));
return out;
})
.def("get_nodes_from_edges",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> edge_list) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
return out;
})
.def("get_all_neighbors",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
return out;
})
.def("get_sampled_neighbors",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
std::vector<gnn::NodeType> neighbor_types) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out));
return out;
})
.def("get_neg_sampled_neighbors",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
gnn::NodeType neg_neighbor_type) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out));
return out;
})
.def("get_node_feature",
[](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
TensorRow out;
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
return out.getRow();
})
.def("get_edge_feature",
[](gnn::Graph &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) {
TensorRow out;
THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out));
return out.getRow();
})
.def("graph_info",
[](gnn::Graph &g) {
py::dict out;
THROW_IF_ERROR(g.GraphInfo(&out));
return out;
})
.def("random_walk",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
float step_home_param, float step_away_param, gnn::NodeIdType default_node) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out));
return out;
});
}));
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/api/python/de_pipeline.h"
#include "mindspore/ccsrc/minddata/dataset/kernels/data/compose_op.h"
#include "mindspore/ccsrc/minddata/dataset/kernels/data/no_op.h"
#include "minddata/dataset/kernels/py_func_op.h"
#include "mindspore/ccsrc/minddata/dataset/kernels/data/random_apply_op.h"
#include "mindspore/ccsrc/minddata/dataset/kernels/data/random_choice_op.h"
namespace mindspore {
namespace dataset {
Status PyListToTensorOps(const py::list &py_ops, std::vector<std::shared_ptr<TensorOp>> *ops) {
RETURN_UNEXPECTED_IF_NULL(ops);
for (auto op : py_ops) {
if (py::isinstance<TensorOp>(op)) {
ops->emplace_back(op.cast<std::shared_ptr<TensorOp>>());
} else if (py::isinstance<py::function>(op)) {
ops->emplace_back(std::make_shared<PyFuncOp>(op.cast<py::function>()));
} else {
RETURN_STATUS_UNEXPECTED("element is neither a TensorOp nor a pyfunc.");
}
}
CHECK_FAIL_RETURN_UNEXPECTED(!ops->empty(), "TensorOp list is empty.");
for (auto const &op : *ops) {
RETURN_UNEXPECTED_IF_NULL(op);
}
return Status::OK();
}
PYBIND_REGISTER(TensorOp, 0, ([](const py::module *m) {
(void)py::class_<TensorOp, std::shared_ptr<TensorOp>>(*m, "TensorOp")
.def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
}));
PYBIND_REGISTER(ComposeOp, 1, ([](const py::module *m) {
(void)py::class_<ComposeOp, TensorOp, std::shared_ptr<ComposeOp>>(*m, "ComposeOp")
.def(py::init([](const py::list &ops) {
std::vector<std::shared_ptr<TensorOp>> t_ops;
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
return std::make_shared<ComposeOp>(t_ops);
}));
}));
PYBIND_REGISTER(NoOp, 1, ([](const py::module *m) {
(void)py::class_<NoOp, TensorOp, std::shared_ptr<NoOp>>(
*m, "NoOp", "TensorOp that does nothing, for testing purposes only.")
.def(py::init<>());
}));
PYBIND_REGISTER(RandomChoiceOp, 1, ([](const py::module *m) {
(void)py::class_<RandomChoiceOp, TensorOp, std::shared_ptr<RandomChoiceOp>>(*m, "RandomChoiceOp")
.def(py::init([](const py::list &ops) {
std::vector<std::shared_ptr<TensorOp>> t_ops;
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
return std::make_shared<RandomChoiceOp>(t_ops);
}));
}));
PYBIND_REGISTER(RandomApplyOp, 1, ([](const py::module *m) {
(void)py::class_<RandomApplyOp, TensorOp, std::shared_ptr<RandomApplyOp>>(*m, "RandomApplyOp")
.def(py::init([](double prob, const py::list &ops) {
std::vector<std::shared_ptr<TensorOp>> t_ops;
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
if (prob < 0 || prob > 1) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be within [0,1]."));
}
return std::make_shared<RandomApplyOp>(prob, t_ops);
}));
}));
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/kernels/data/concatenate_op.h"
#include "minddata/dataset/kernels/data/duplicate_op.h"
#include "minddata/dataset/kernels/data/fill_op.h"
#include "minddata/dataset/kernels/data/mask_op.h"
#include "minddata/dataset/kernels/data/one_hot_op.h"
#include "minddata/dataset/kernels/data/pad_end_op.h"
#include "minddata/dataset/kernels/data/slice_op.h"
#include "minddata/dataset/kernels/data/to_float16_op.h"
#include "minddata/dataset/kernels/data/type_cast_op.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(ConcatenateOp, 1, ([](const py::module *m) {
(void)py::class_<ConcatenateOp, TensorOp, std::shared_ptr<ConcatenateOp>>(
*m, "ConcatenateOp", "Tensor operation concatenate tensors.")
.def(py::init<int8_t, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>(), py::arg("axis"),
py::arg("prepend").none(true), py::arg("append").none(true));
}));
PYBIND_REGISTER(DuplicateOp, 1, ([](const py::module *m) {
(void)py::class_<DuplicateOp, TensorOp, std::shared_ptr<DuplicateOp>>(*m, "DuplicateOp",
"Duplicate tensor.")
.def(py::init<>());
}));
PYBIND_REGISTER(FillOp, 1, ([](const py::module *m) {
(void)py::class_<FillOp, TensorOp, std::shared_ptr<FillOp>>(
*m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.")
.def(py::init<std::shared_ptr<Tensor>>());
}));
PYBIND_REGISTER(MaskOp, 1, ([](const py::module *m) {
(void)py::class_<MaskOp, TensorOp, std::shared_ptr<MaskOp>>(
*m, "MaskOp", "Tensor mask operation using relational comparator")
.def(py::init<RelationalOp, std::shared_ptr<Tensor>, DataType>());
}));
PYBIND_REGISTER(OneHotOp, 1, ([](const py::module *m) {
(void)py::class_<OneHotOp, TensorOp, std::shared_ptr<OneHotOp>>(
*m, "OneHotOp", "Tensor operation to apply one hot encoding. Takes number of classes.")
.def(py::init<int32_t>());
}));
PYBIND_REGISTER(PadEndOp, 1, ([](const py::module *m) {
(void)py::class_<PadEndOp, TensorOp, std::shared_ptr<PadEndOp>>(
*m, "PadEndOp", "Tensor operation to pad end of tensor with a pad value.")
.def(py::init<TensorShape, std::shared_ptr<Tensor>>());
}));
PYBIND_REGISTER(SliceOp, 1, ([](const py::module *m) {
(void)py::class_<SliceOp, TensorOp, std::shared_ptr<SliceOp>>(*m, "SliceOp",
"Tensor slice operation.")
.def(py::init<bool>())
.def(py::init([](const py::list &py_list) {
std::vector<dsize_t> c_list;
for (auto l : py_list) {
if (!l.is_none()) {
c_list.push_back(py::reinterpret_borrow<py::int_>(l));
}
}
return std::make_shared<SliceOp>(c_list);
}))
.def(py::init([](const py::tuple &py_slice) {
if (py_slice.size() != 3) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
}
Slice c_slice;
if (!py_slice[0].is_none() && !py_slice[1].is_none() && !py_slice[2].is_none()) {
c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[0]),
py::reinterpret_borrow<py::int_>(py_slice[1]),
py::reinterpret_borrow<py::int_>(py_slice[2]));
} else if (py_slice[0].is_none() && py_slice[2].is_none()) {
c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[1]));
} else if (!py_slice[0].is_none() && !py_slice[1].is_none()) {
c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[0]),
py::reinterpret_borrow<py::int_>(py_slice[1]));
}
if (!c_slice.valid()) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
}
return std::make_shared<SliceOp>(c_slice);
}));
}));
PYBIND_REGISTER(ToFloat16Op, 1, ([](const py::module *m) {
(void)py::class_<ToFloat16Op, TensorOp, std::shared_ptr<ToFloat16Op>>(
*m, "ToFloat16Op", py::dynamic_attr(),
"Tensor operator to type cast float32 data to a float16 type.")
.def(py::init<>());
}));
PYBIND_REGISTER(TypeCastOp, 1, ([](const py::module *m) {
(void)py::class_<TypeCastOp, TensorOp, std::shared_ptr<TypeCastOp>>(
*m, "TypeCastOp", "Tensor operator to type cast data to a specified type.")
.def(py::init<DataType>(), py::arg("data_type"))
.def(py::init<std::string>(), py::arg("data_type"));
}));
PYBIND_REGISTER(RelationalOp, 0, ([](const py::module *m) {
(void)py::enum_<RelationalOp>(*m, "RelationalOp", py::arithmetic())
.value("EQ", RelationalOp::kEqual)
.value("NE", RelationalOp::kNotEqual)
.value("LT", RelationalOp::kLess)
.value("LE", RelationalOp::kLessEqual)
.value("GT", RelationalOp::kGreater)
.value("GE", RelationalOp::kGreaterEqual)
.export_values();
}));
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/text/vocab.h"
#include "minddata/dataset/text/sentence_piece_vocab.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(Vocab, 0, ([](const py::module *m) {
(void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
.def(py::init<>())
.def_static("from_list",
[](const py::list &words, const py::list &special_tokens, bool special_first) {
std::shared_ptr<Vocab> v;
THROW_IF_ERROR(Vocab::BuildFromPyList(words, special_tokens, special_first, &v));
return v;
})
.def_static(
"from_file",
[](const std::string &path, const std::string &dlm, int32_t vocab_size,
const py::list &special_tokens, bool special_first) {
std::shared_ptr<Vocab> v;
THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, special_tokens, special_first, &v));
return v;
})
.def_static("from_dict", [](const py::dict &words) {
std::shared_ptr<Vocab> v;
THROW_IF_ERROR(Vocab::BuildFromPyDict(words, &v));
return v;
});
}));
PYBIND_REGISTER(SentencePieceVocab, 0, ([](const py::module *m) {
(void)py::class_<SentencePieceVocab, std::shared_ptr<SentencePieceVocab>>(*m, "SentencePieceVocab")
.def(py::init<>())
.def_static("from_file",
[](const py::list &paths, const int vocab_size, const float character_coverage,
const SentencePieceModel model_type, const py::dict &params) {
std::shared_ptr<SentencePieceVocab> v;
std::vector<std::string> path_list;
for (auto path : paths) {
path_list.emplace_back(py::str(path));
}
std::unordered_map<std::string, std::string> param_map;
for (auto param : params) {
std::string key = py::reinterpret_borrow<py::str>(param.first);
if (key == "input" || key == "vocab_size" || key == "model_prefix" ||
key == "character_coverage" || key == "model_type") {
continue;
}
param_map[key] = py::reinterpret_borrow<py::str>(param.second);
}
THROW_IF_ERROR(SentencePieceVocab::BuildFromFile(
path_list, vocab_size, character_coverage, model_type, param_map, &v));
return v;
})
.def_static("save_model", [](const std::shared_ptr<SentencePieceVocab> *vocab, std::string path,
std::string filename) {
THROW_IF_ERROR(SentencePieceVocab::SaveModel(vocab, path, filename));
});
}));
PYBIND_REGISTER(SentencePieceModel, 0, ([](const py::module *m) {
(void)py::enum_<SentencePieceModel>(*m, "SentencePieceModel", py::arithmetic())
.value("DE_SENTENCE_PIECE_UNIGRAM", SentencePieceModel::kUnigram)
.value("DE_SENTENCE_PIECE_BPE", SentencePieceModel::kBpe)
.value("DE_SENTENCE_PIECE_CHAR", SentencePieceModel::kChar)
.value("DE_SENTENCE_PIECE_WORD", SentencePieceModel::kWord)
.export_values();
}));
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h"
#include "minddata/dataset/text/kernels/lookup_op.h"
#include "minddata/dataset/text/kernels/ngram_op.h"
#include "minddata/dataset/text/kernels/sliding_window_op.h"
#include "minddata/dataset/text/kernels/to_number_op.h"
#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h"
#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h"
#include "minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h"
#include "minddata/dataset/text/kernels/truncate_sequence_pair_op.h"
#ifdef ENABLE_ICU4C
#include "minddata/dataset/text/kernels/basic_tokenizer_op.h"
#include "minddata/dataset/text/kernels/bert_tokenizer_op.h"
#include "minddata/dataset/text/kernels/case_fold_op.h"
#include "minddata/dataset/text/kernels/normalize_utf8_op.h"
#include "minddata/dataset/text/kernels/regex_replace_op.h"
#include "minddata/dataset/text/kernels/regex_tokenizer_op.h"
#include "minddata/dataset/text/kernels/unicode_script_tokenizer_op.h"
#include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h"
#endif
namespace mindspore {
namespace dataset {
#ifdef ENABLE_ICU4C
PYBIND_REGISTER(BasicTokenizerOp, 1, ([](const py::module *m) {
(void)py::class_<BasicTokenizerOp, TensorOp, std::shared_ptr<BasicTokenizerOp>>(
*m, "BasicTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by specific rules.")
.def(py::init<const bool &, const bool &, const NormalizeForm &, const bool &, const bool &>(),
py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase,
py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace,
py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm,
py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken,
py::arg("with_offsets") = BasicTokenizerOp::kDefWithOffsets);
}));
PYBIND_REGISTER(WhitespaceTokenizerOp, 1, ([](const py::module *m) {
(void)py::class_<WhitespaceTokenizerOp, TensorOp, std::shared_ptr<WhitespaceTokenizerOp>>(
*m, "WhitespaceTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on ICU defined whitespaces.")
.def(py::init<const bool &>(), py::arg(" with_offsets ") = WhitespaceTokenizerOp::kDefWithOffsets);
}));
PYBIND_REGISTER(UnicodeScriptTokenizerOp, 1, ([](const py::module *m) {
(void)py::class_<UnicodeScriptTokenizerOp, TensorOp, std::shared_ptr<UnicodeScriptTokenizerOp>>(
*m, "UnicodeScriptTokenizerOp",
"Tokenize a scalar tensor of UTF-8 string on Unicode script boundaries.")
.def(py::init<>())
.def(py::init<const bool &, const bool &>(),
py::arg("keep_whitespace") = UnicodeScriptTokenizerOp::kDefKeepWhitespace,
py::arg("with_offsets") = UnicodeScriptTokenizerOp::kDefWithOffsets);
}));
PYBIND_REGISTER(CaseFoldOp, 1, ([](const py::module *m) {
(void)py::class_<CaseFoldOp, TensorOp, std::shared_ptr<CaseFoldOp>>(
*m, "CaseFoldOp", "Apply case fold operation on utf-8 string tensor")
.def(py::init<>());
}));
PYBIND_REGISTER(NormalizeUTF8Op, 1, ([](const py::module *m) {
(void)py::class_<NormalizeUTF8Op, TensorOp, std::shared_ptr<NormalizeUTF8Op>>(
*m, "NormalizeUTF8Op", "Apply normalize operation on utf-8 string tensor.")
.def(py::init<>())
.def(py::init<NormalizeForm>(), py::arg("normalize_form") = NormalizeUTF8Op::kDefNormalizeForm);
}));
PYBIND_REGISTER(RegexReplaceOp, 1, ([](const py::module *m) {
(void)py::class_<RegexReplaceOp, TensorOp, std::shared_ptr<RegexReplaceOp>>(
*m, "RegexReplaceOp",
"Replace utf-8 string tensor with 'replace' according to regular expression 'pattern'.")
.def(py::init<const std::string &, const std::string &, bool>(), py::arg("pattern"),
py::arg("replace"), py::arg("replace_all"));
}));
PYBIND_REGISTER(RegexTokenizerOp, 1, ([](const py::module *m) {
(void)py::class_<RegexTokenizerOp, TensorOp, std::shared_ptr<RegexTokenizerOp>>(
*m, "RegexTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by regex expression pattern.")
.def(py::init<const std::string &, const std::string &, const bool &>(), py::arg("delim_pattern"),
py::arg("keep_delim_pattern"), py::arg("with_offsets") = RegexTokenizerOp::kDefWithOffsets);
}));
PYBIND_REGISTER(BertTokenizerOp, 1, ([](const py::module *m) {
(void)py::class_<BertTokenizerOp, TensorOp, std::shared_ptr<BertTokenizerOp>>(
*m, "BertTokenizerOp", "Tokenizer used for Bert text process.")
.def(py::init<const std::shared_ptr<Vocab> &, const std::string &, const int &, const std::string &,
const bool &, const bool &, const NormalizeForm &, const bool &, const bool &>(),
py::arg("vocab"),
py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator),
py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken,
py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken),
py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase,
py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace,
py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm,
py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken,
py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets);
}));
PYBIND_REGISTER(NormalizeForm, 0, ([](const py::module *m) {
(void)py::enum_<NormalizeForm>(*m, "NormalizeForm", py::arithmetic())
.value("DE_NORMALIZE_NONE", NormalizeForm::kNone)
.value("DE_NORMALIZE_NFC", NormalizeForm::kNfc)
.value("DE_NORMALIZE_NFKC", NormalizeForm::kNfkc)
.value("DE_NORMALIZE_NFD", NormalizeForm::kNfd)
.value("DE_NORMALIZE_NFKD", NormalizeForm::kNfkd)
.export_values();
}));
#endif
PYBIND_REGISTER(JiebaTokenizerOp, 1, ([](const py::module *m) {
(void)py::class_<JiebaTokenizerOp, TensorOp, std::shared_ptr<JiebaTokenizerOp>>(
*m, "JiebaTokenizerOp", "")
.def(py::init<const std::string &, const std::string &, const JiebaMode &, const bool &>(),
py::arg("hmm_path"), py::arg("mp_path"), py::arg("mode") = JiebaMode::kMix,
py::arg("with_offsets") = JiebaTokenizerOp::kDefWithOffsets)
.def("add_word", [](JiebaTokenizerOp &self, const std::string word, int freq) {
THROW_IF_ERROR(self.AddWord(word, freq));
});
}));
PYBIND_REGISTER(UnicodeCharTokenizerOp, 1, ([](const py::module *m) {
(void)py::class_<UnicodeCharTokenizerOp, TensorOp, std::shared_ptr<UnicodeCharTokenizerOp>>(
*m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.")
.def(py::init<const bool &>(), py::arg("with_offsets") = UnicodeCharTokenizerOp::kDefWithOffsets);
}));
PYBIND_REGISTER(LookupOp, 1, ([](const py::module *m) {
(void)py::class_<LookupOp, TensorOp, std::shared_ptr<LookupOp>>(
*m, "LookupOp", "Tensor operation to LookUp each word.")
.def(py::init([](std::shared_ptr<Vocab> vocab, const py::object &py_word) {
if (vocab == nullptr) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "vocab object type is incorrect or null."));
}
if (py_word.is_none()) {
return std::make_shared<LookupOp>(vocab, Vocab::kNoTokenExists);
}
std::string word = py::reinterpret_borrow<py::str>(py_word);
WordIdType default_id = vocab->Lookup(word);
if (default_id == Vocab::kNoTokenExists) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError,
"default unknown token: " + word + " doesn't exist in vocab."));
}
return std::make_shared<LookupOp>(vocab, default_id);
}));
}));
PYBIND_REGISTER(NgramOp, 1, ([](const py::module *m) {
(void)py::class_<NgramOp, TensorOp, std::shared_ptr<NgramOp>>(*m, "NgramOp",
"TensorOp performs ngram mapping.")
.def(py::init<const std::vector<int32_t> &, int32_t, int32_t, const std::string &,
const std::string &, const std::string &>(),
py::arg("ngrams"), py::arg("l_pad_len"), py::arg("r_pad_len"), py::arg("l_pad_token"),
py::arg("r_pad_token"), py::arg("separator"));
}));
PYBIND_REGISTER(
WordpieceTokenizerOp, 1, ([](const py::module *m) {
(void)py::class_<WordpieceTokenizerOp, TensorOp, std::shared_ptr<WordpieceTokenizerOp>>(
*m, "WordpieceTokenizerOp", "Tokenize scalar token or 1-D tokens to subword tokens.")
.def(
py::init<const std::shared_ptr<Vocab> &, const std::string &, const int &, const std::string &, const bool &>(),
py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator),
py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken,
py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken),
py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets);
}));
PYBIND_REGISTER(SlidingWindowOp, 1, ([](const py::module *m) {
(void)py::class_<SlidingWindowOp, TensorOp, std::shared_ptr<SlidingWindowOp>>(
*m, "SlidingWindowOp", "TensorOp to apply sliding window to a 1-D Tensor.")
.def(py::init<uint32_t, int32_t>(), py::arg("width"), py::arg("axis"));
}));
PYBIND_REGISTER(
SentencePieceTokenizerOp, 1, ([](const py::module *m) {
(void)py::class_<SentencePieceTokenizerOp, TensorOp, std::shared_ptr<SentencePieceTokenizerOp>>(
*m, "SentencePieceTokenizerOp", "Tokenize scalar token or 1-D tokens to tokens by sentence piece.")
.def(
py::init<std::shared_ptr<SentencePieceVocab> &, const SPieceTokenizerLoadType, const SPieceTokenizerOutType>(),
py::arg("vocab"), py::arg("load_type") = SPieceTokenizerLoadType::kModel,
py::arg("out_type") = SPieceTokenizerOutType::kString)
.def(py::init<const std::string &, const std::string &, const SPieceTokenizerLoadType,
const SPieceTokenizerOutType>(),
py::arg("model_path"), py::arg("model_filename"), py::arg("load_type") = SPieceTokenizerLoadType::kFile,
py::arg("out_type") = SPieceTokenizerOutType::kString);
}));
PYBIND_REGISTER(ToNumberOp, 1, ([](const py::module *m) {
(void)py::class_<ToNumberOp, TensorOp, std::shared_ptr<ToNumberOp>>(
*m, "ToNumberOp", "TensorOp to convert strings to numbers.")
.def(py::init<DataType>(), py::arg("data_type"))
.def(py::init<std::string>(), py::arg("data_type"));
}));
PYBIND_REGISTER(TruncateSequencePairOp, 1, ([](const py::module *m) {
(void)py::class_<TruncateSequencePairOp, TensorOp, std::shared_ptr<TruncateSequencePairOp>>(
*m, "TruncateSequencePairOp", "Tensor operation to truncate two tensors to a max_length")
.def(py::init<int64_t>());
}));
PYBIND_REGISTER(JiebaMode, 0, ([](const py::module *m) {
(void)py::enum_<JiebaMode>(*m, "JiebaMode", py::arithmetic())
.value("DE_JIEBA_MIX", JiebaMode::kMix)
.value("DE_JIEBA_MP", JiebaMode::kMp)
.value("DE_JIEBA_HMM", JiebaMode::kHmm)
.export_values();
}));
PYBIND_REGISTER(SPieceTokenizerOutType, 0, ([](const py::module *m) {
(void)py::enum_<SPieceTokenizerOutType>(*m, "SPieceTokenizerOutType", py::arithmetic())
.value("DE_SPIECE_TOKENIZER_OUTTYPE_KString", SPieceTokenizerOutType::kString)
.value("DE_SPIECE_TOKENIZER_OUTTYPE_KINT", SPieceTokenizerOutType::kInt)
.export_values();
}));
PYBIND_REGISTER(SPieceTokenizerLoadType, 0, ([](const py::module *m) {
(void)py::enum_<SPieceTokenizerLoadType>(*m, "SPieceTokenizerLoadType", py::arithmetic())
.value("DE_SPIECE_TOKENIZER_LOAD_KFILE", SPieceTokenizerLoadType::kFile)
.value("DE_SPIECE_TOKENIZER_LOAD_KMODEL", SPieceTokenizerLoadType::kModel)
.export_values();
}));
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/util/random.h"
#include "minddata/mindrecord/include/shard_distributed_sample.h"
#include "minddata/mindrecord/include/shard_operator.h"
#include "minddata/mindrecord/include/shard_pk_sample.h"
#include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_sequential_sample.h"
#include "minddata/mindrecord/include/shard_shuffle.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(ShardOperator, 0, ([](const py::module *m) {
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(
*m, "ShardOperator")
.def("add_child",
[](std::shared_ptr<mindrecord::ShardOperator> self,
std::shared_ptr<mindrecord::ShardOperator> child) { self->SetChildOp(child); });
}));
PYBIND_REGISTER(ShardDistributedSample, 1, ([](const py::module *m) {
(void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
std::shared_ptr<mindrecord::ShardDistributedSample>>(*m,
"MindrecordDistributedSampler")
.def(py::init<int64_t, int64_t, bool, uint32_t, int64_t>());
}));
PYBIND_REGISTER(
ShardPkSample, 1, ([](const py::module *m) {
(void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
*m, "MindrecordPkSampler")
.def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) {
if (shuffle == true) {
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
GetSeed());
} else {
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal);
}
}));
}));
PYBIND_REGISTER(
ShardSample, 0, ([](const py::module *m) {
(void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
*m, "MindrecordSubsetRandomSampler")
.def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());
}));
PYBIND_REGISTER(ShardSequentialSample, 0, ([](const py::module *m) {
(void)py::class_<mindrecord::ShardSequentialSample, mindrecord::ShardSample,
std::shared_ptr<mindrecord::ShardSequentialSample>>(*m,
"MindrecordSequentialSampler")
.def(py::init([](int num_samples, int start_index) {
return std::make_shared<mindrecord::ShardSequentialSample>(num_samples, start_index);
}));
}));
PYBIND_REGISTER(
ShardShuffle, 1, ([](const py::module *m) {
(void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
*m, "MindrecordRandomSampler")
.def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) {
return std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples, replacement, reshuffle_each_epoch);
}));
}));
} // namespace dataset
} // namespace mindspore
@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/api/de_pipeline.h"
#include "minddata/dataset/api/python/de_pipeline.h"
#include <algorithm>
#include <set>
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/api/python/pybind_register.h"
namespace mindspore {
namespace dataset {
PybindDefinedFunctionRegister &PybindDefinedFunctionRegister::GetSingleton() {
static PybindDefinedFunctionRegister instance;
return instance;
}
// This is where we externalize the C logic as python modules
PYBIND11_MODULE(_c_dataengine, m) {
m.doc() = "pybind11 for _c_dataengine";
auto all_fns = mindspore::dataset::PybindDefinedFunctionRegister::AllFunctions();
for (auto &item : all_fns) {
for (auto &func : item.second) {
func.second(&m);
}
}
}
} // namespace dataset
} // namespace mindspore
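PYBIND11_MODULE above does no binding of its own: it walks AllFunctions(), a std::map keyed first by priority and then by registration name, so every priority-0 entry (the base classes and the enums) is defined before the priority-1 entries that derive from them. As a rough guide to the new structure, the following is a minimal sketch, assuming a hypothetical DummyCounter type that does not exist in MindSpore, of what one of the per-domain binding files now looks like; it is illustration only and not part of this commit.

// Illustration-only sketch, not part of this commit: how a new binding file
// would plug into the PYBIND_REGISTER mechanism introduced here. DummyCounter
// is a hypothetical type invented for the example; the real binding files above
// register existing minddata classes in exactly the same way.
#include <cstdint>
#include <memory>
#include "pybind11/pybind11.h"
#include "minddata/dataset/api/python/pybind_register.h"
namespace mindspore {
namespace dataset {
// Hypothetical C++ type exposed to Python purely for demonstration.
struct DummyCounter {
  int64_t count = 0;
  void Increment() { ++count; }
};
// Priority 0 is used above for bindings that others depend on (TensorOp,
// DatasetOp, Sampler, the enums); a leaf binding like this one uses priority 1
// so that any base classes are already defined when its lambda runs.
PYBIND_REGISTER(DummyCounter, 1, ([](const py::module *m) {
                  (void)py::class_<DummyCounter, std::shared_ptr<DummyCounter>>(*m, "DummyCounter")
                    .def(py::init<>())
                    .def("increment", &DummyCounter::Increment)
                    .def_readonly("count", &DummyCounter::count);
                }));
}  // namespace dataset
}  // namespace mindspore

Because the macro instantiates a global PybindDefineRegisterer during static initialization, compiling such a file into the APItoPython object library is all that is needed for the lambda to be invoked when the _c_dataengine module is imported.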
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef API_PYBIND_API_H_
#define API_PYBIND_API_H_
#include <map>
#include <string>
#include <memory>
#include <functional>
#include <utility>
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace mindspore {
namespace dataset {
#define THROW_IF_ERROR(s) \
do { \
Status rc = std::move(s); \
if (rc.IsError()) throw std::runtime_error(rc.ToString()); \
} while (false)
using PybindDefineFunc = std::function<void(py::module *)>;
class PybindDefinedFunctionRegister {
public:
static void Register(const std::string &name, const uint8_t &priority, const PybindDefineFunc &fn) {
return GetSingleton().RegisterFn(name, priority, fn);
}
PybindDefinedFunctionRegister(const PybindDefinedFunctionRegister &) = delete;
PybindDefinedFunctionRegister &operator=(const PybindDefinedFunctionRegister &) = delete;
static std::map<uint8_t, std::map<std::string, PybindDefineFunc>> &AllFunctions() {
return GetSingleton().module_fns_;
}
std::map<uint8_t, std::map<std::string, PybindDefineFunc>> module_fns_;
protected:
PybindDefinedFunctionRegister() = default;
virtual ~PybindDefinedFunctionRegister() = default;
static PybindDefinedFunctionRegister &GetSingleton();
void RegisterFn(const std::string &name, const uint8_t &priority, const PybindDefineFunc &fn) {
module_fns_[priority][name] = fn;
}
};
class PybindDefineRegisterer {
public:
PybindDefineRegisterer(const std::string &name, const uint8_t &priority, const PybindDefineFunc &fn) {
PybindDefinedFunctionRegister::Register(name, priority, fn);
}
~PybindDefineRegisterer() = default;
};
#ifdef ENABLE_PYTHON
#define PYBIND_REGISTER(name, priority, define) PybindDefineRegisterer g_pybind_define_f_##name(#name, priority, define)
#endif
} // namespace dataset
} // namespace mindspore
#endif // API_PYBIND_API_H_
@@ -4,16 +4,16 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
if (ENABLE_PYTHON)
add_library(kernels OBJECT
-compose_op.cc
-random_apply_op.cc
-random_choice_op.cc
+data/compose_op.cc
+data/random_apply_op.cc
+data/random_choice_op.cc
py_func_op.cc
tensor_op.cc)
target_include_directories(kernels PRIVATE ${pybind11_INCLUDE_DIRS})
else()
add_library(kernels OBJECT
-compose_op.cc
-random_apply_op.cc
-random_choice_op.cc
+data/compose_op.cc
+data/random_apply_op.cc
+data/random_choice_op.cc
tensor_op.cc)
endif()
@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/compose_op.h"
#include "minddata/dataset/kernels/data/compose_op.h"
#include <vector>
......
@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/random_apply_op.h"
#include "minddata/dataset/kernels/data/random_apply_op.h"
#include <memory>
#include <vector>
......
@@ -24,7 +24,7 @@
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/compose_op.h"
#include "minddata/dataset/kernels/data/compose_op.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/random.h"
......
@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/random_choice_op.h"
#include "minddata/dataset/kernels/data/random_choice_op.h"
#include <memory>
#include <vector>
......
@@ -25,7 +25,7 @@
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/kernels/compose_op.h"
#include "minddata/dataset/kernels/data/compose_op.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
......
@@ -29,7 +29,6 @@ namespace mindspore {
namespace dataset {
class RandomVerticalFlipWithBBoxOp : public TensorOp {
public:
-// Default values, also used by python_bindings.cc
static const float kDefProbability;
// Constructor for RandomVerticalFlipWithBBoxOp
// @param probability: Probablity of Image flipping, 0.5 by default
......