提交 c9de2af2 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3116 [MD]Add sentencepiece Vocab and Tokenizer

Merge pull request !3116 from xulei/sentence_piece0715
if (WIN32)
set(sentencepiece_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 -Wno-unused-result -Wno-stringop-overflow -Wno-format-extra-args -Wno-format")
set(sentencepiece_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(sentencepiece
VER 0.1.92
LIBS sentencepiece sentencepiece_train
URL https://github.com/google/sentencepiece/archive/v0.1.92.tar.gz
CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -DSPM_USE_BUILTIN_PROTOBUF=ON
MD5 5dfd2241914b5598a68b2a8542ed8e91
)
else ()
set(sentencepiece_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 -Wno-unused-result -Wno-sign-compare")
set(sentencepiece_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(sentencepiece
VER 0.1.92
LIBS sentencepiece sentencepiece_train
URL https://github.com/google/sentencepiece/archive/v0.1.92.tar.gz
CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -DSPM_USE_BUILTIN_PROTOBUF=OFF -DSPM_ENABLE_SHARED=OFF -DPROTOBUF_INC=${protobuf_INC}
MD5 5dfd2241914b5598a68b2a8542ed8e91
PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/sentencepiece/sentencepiece.patch001
)
endif ()
include_directories(${sentencepiece_INC})
add_library(mindspore::sentencepiece ALIAS sentencepiece::sentencepiece)
add_library(mindspore::sentencepiece_train ALIAS sentencepiece::sentencepiece_train)
\ No newline at end of file
......@@ -75,6 +75,7 @@ if (ENABLE_MINDDATA)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/sqlite.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/tinyxml2.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/cppjieba.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/sentencepiece.cmake)
endif()
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/gtest.cmake)
......
......@@ -40,6 +40,7 @@ if (CMAKE_SYSTEM_NAME MATCHES "Windows")
set(jpeg_turbo_LIBPATH ${jpeg_turbo_LIBPATH}/../bin/)
set(sqlite_LIBPATH ${sqlite_LIBPATH}/../bin/)
set(tinyxml2_LIBPATH ${tinyxml2_LIBPATH}/../bin/)
set(sentencepiece_LIBPATH ${sentencepiece_LIBPATH}/../bin/)
else ()
set(INSTALL_LIB_DIR "lib")
endif ()
......@@ -91,6 +92,14 @@ if (ENABLE_MINDDATA)
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
file(GLOB_RECURSE SENTENCEPIECE_LIB_LIST
${sentencepiece_LIBPATH}/libsentencepiece*
)
install(
FILES ${SENTENCEPIECE_LIB_LIST}
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
message("icu4c does not support windows system temporarily")
else()
......
......@@ -128,7 +128,7 @@ else()
target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module -ldl mindspore::protobuf ${SECUREC_LIBRARY})
endif()
target_link_libraries(_c_dataengine PUBLIC mindspore::jpeg_turbo mindspore::opencv_core mindspore::opencv_imgcodecs
mindspore::opencv_imgproc mindspore::tinyxml2 ${ICU_LIB})
mindspore::opencv_imgproc mindspore::tinyxml2 mindspore::sentencepiece mindspore::sentencepiece_train ${ICU_LIB})
if (ENABLE_GPUQUE)
target_link_libraries(_c_dataengine PRIVATE gpu_queue
${CUDNN_PATH}/lib64/libcudnn.so
......
......@@ -87,7 +87,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
{kTextFile, &DEPipeline::ParseTextFileOp},
{kBuildVocab, &DEPipeline::ParseBuildVocabOp},
{kClue, &DEPipeline::ParseClueOp},
{kEpochCtrl, &DEPipeline::ParseEpochCtrlOp}};
{kEpochCtrl, &DEPipeline::ParseEpochCtrlOp},
{kSentencePieceVocab, &DEPipeline::ParseBuildSentencePieceVocabOp}};
DEPipeline::DEPipeline() : iterator_(nullptr) {
try {
......@@ -1710,6 +1711,41 @@ Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<Datas
return Status::OK();
}
Status DEPipeline::ParseBuildSentencePieceVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom) {
std::shared_ptr<BuildSentencePieceVocabOp::Builder> builder = std::make_shared<BuildSentencePieceVocabOp::Builder>();
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "vocab_size") {
builder->SetVocabSize(ToInt(value));
} else if (key == "character_coverage") {
(void)builder->SetCharacterCoverage(ToFloat(value));
} else if (key == "params") {
std::unordered_map<std::string, std::string> params;
for (auto param : py::reinterpret_borrow<py::dict>(value)) {
std::string param_key = py::reinterpret_borrow<py::str>(param.first);
if (param_key == "input" || param_key == "vocab_size" || param_key == "model_prefix" ||
param_key == "character_coverage" || param_key == "model_type") {
continue;
}
params[param_key] = py::reinterpret_borrow<py::str>(param.second);
}
(void)builder->SetParams(params);
} else if (key == "vocab") {
(void)builder->SetVocab(value.cast<std::shared_ptr<SentencePieceVocab>>());
} else if (key == "model_type") {
(void)builder->SetModelType(value.cast<SentencePieceModel>());
}
}
}
std::shared_ptr<BuildSentencePieceVocabOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*top = op;
return Status::OK();
}
Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom) {
std::vector<std::string> files_list;
......
......@@ -71,7 +71,8 @@ enum OpName {
kTextFile,
kBuildVocab,
kClue,
kEpochCtrl
kEpochCtrl,
kSentencePieceVocab,
};
// The C++ binder class that we expose to the python script.
......@@ -195,6 +196,8 @@ class DEPipeline {
Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status StopSend();
Status ParseBuildSentencePieceVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom);
Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
......
......@@ -88,7 +88,9 @@
#include "minddata/dataset/text/kernels/to_number_op.h"
#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h"
#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h"
#include "minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h"
#include "minddata/dataset/text/vocab.h"
#include "minddata/dataset/text/sentence_piece_vocab.h"
#include "minddata/dataset/util/random.h"
#include "minddata/mindrecord/include/shard_distributed_sample.h"
#include "minddata/mindrecord/include/shard_operator.h"
......@@ -684,6 +686,15 @@ void bindTokenizerOps(py::module *m) {
(void)py::class_<SlidingWindowOp, TensorOp, std::shared_ptr<SlidingWindowOp>>(
*m, "SlidingWindowOp", "TensorOp to apply sliding window to a 1-D Tensor.")
.def(py::init<uint32_t, int32_t>(), py::arg("width"), py::arg("axis"));
(void)py::class_<SentencePieceTokenizerOp, TensorOp, std::shared_ptr<SentencePieceTokenizerOp>>(
*m, "SentencePieceTokenizerOp", "Tokenize scalar token or 1-D tokens to tokens by sentence piece.")
.def(py::init<std::shared_ptr<SentencePieceVocab> &, const SPieceTokenizerLoadType, const SPieceTokenizerOutType>(),
py::arg("vocab"), py::arg("load_type") = SPieceTokenizerLoadType::kModel,
py::arg("out_type") = SPieceTokenizerOutType::kString)
.def(
py::init<const std::string &, const std::string &, const SPieceTokenizerLoadType, const SPieceTokenizerOutType>(),
py::arg("model_path"), py::arg("model_filename"), py::arg("load_type") = SPieceTokenizerLoadType::kFile,
py::arg("out_type") = SPieceTokenizerOutType::kString);
}
void bindDependIcuTokenizerOps(py::module *m) {
......@@ -839,6 +850,33 @@ void bindVocabObjects(py::module *m) {
THROW_IF_ERROR(Vocab::BuildFromPyDict(words, &v));
return v;
});
(void)py::class_<SentencePieceVocab, std::shared_ptr<SentencePieceVocab>>(*m, "SentencePieceVocab")
.def(py::init<>())
.def_static("from_file",
[](const py::list &paths, const int vocab_size, const float character_coverage,
const SentencePieceModel model_type, const py::dict &params) {
std::shared_ptr<SentencePieceVocab> v;
std::vector<std::string> path_list;
for (auto path : paths) {
path_list.emplace_back(py::str(path));
}
std::unordered_map<std::string, std::string> param_map;
for (auto param : params) {
std::string key = py::reinterpret_borrow<py::str>(param.first);
if (key == "input" || key == "vocab_size" || key == "model_prefix" || key == "character_coverage" ||
key == "model_type") {
continue;
}
param_map[key] = py::reinterpret_borrow<py::str>(param.second);
}
THROW_IF_ERROR(SentencePieceVocab::BuildFromFile(path_list, vocab_size, character_coverage,
model_type, param_map, &v));
return v;
})
.def_static("save_model",
[](const std::shared_ptr<SentencePieceVocab> *vocab, std::string path, std::string filename) {
THROW_IF_ERROR(SentencePieceVocab::SaveModel(vocab, path, filename));
});
}
void bindGraphData(py::module *m) {
......@@ -998,6 +1036,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("CIFAR100", OpName::kCifar100)
.value("RANDOMDATA", OpName::kRandomData)
.value("BUILDVOCAB", OpName::kBuildVocab)
.value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab)
.value("CELEBA", OpName::kCelebA)
.value("TEXTFILE", OpName::kTextFile)
.value("CLUE", OpName::kClue)
......@@ -1032,6 +1071,24 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("DE_BORDER_REFLECT", BorderType::kReflect)
.value("DE_BORDER_SYMMETRIC", BorderType::kSymmetric)
.export_values();
(void)py::enum_<SentencePieceModel>(m, "SentencePieceModel", py::arithmetic())
.value("DE_SENTENCE_PIECE_UNIGRAM", SentencePieceModel::kUnigram)
.value("DE_SENTENCE_PIECE_BPE", SentencePieceModel::kBpe)
.value("DE_SENTENCE_PIECE_CHAR", SentencePieceModel::kChar)
.value("DE_SENTENCE_PIECE_WORD", SentencePieceModel::kWord)
.export_values();
(void)py::enum_<SPieceTokenizerOutType>(m, "SPieceTokenizerOutType", py::arithmetic())
.value("DE_SPIECE_TOKENIZER_OUTTYPE_KString", SPieceTokenizerOutType::kString)
.value("DE_SPIECE_TOKENIZER_OUTTYPE_KINT", SPieceTokenizerOutType::kInt)
.export_values();
(void)py::enum_<SPieceTokenizerLoadType>(m, "SPieceTokenizerLoadType", py::arithmetic())
.value("DE_SPIECE_TOKENIZER_LOAD_KFILE", SPieceTokenizerLoadType::kFile)
.value("DE_SPIECE_TOKENIZER_LOAD_KMODEL", SPieceTokenizerLoadType::kModel)
.export_values();
bindDEPipeline(&m);
bindTensor(&m);
bindTensorOps1(&m);
......
......@@ -33,6 +33,7 @@
#include "minddata/dataset/engine/datasetops/filter_op.h"
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
#include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h"
#endif
#include "minddata/dataset/engine/datasetops/batch_op.h"
......
......@@ -32,6 +32,7 @@ if (ENABLE_PYTHON)
barrier_op.cc
filter_op.cc
build_vocab_op.cc
build_sentence_piece_vocab_op.cc
)
endif()
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h"
#include <iomanip>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
BuildSentencePieceVocabOp::BuildSentencePieceVocabOp(std::shared_ptr<SentencePieceVocab> vocab,
std::vector<std::string> col_names, uint32_t vocab_size,
float character_coverage, SentencePieceModel model_type,
const std::unordered_map<std::string, std::string> &params,
int32_t op_conn_size)
: PipelineOp(op_conn_size),
vocab_size_(vocab_size),
vocab_(vocab),
col_names_(col_names),
character_coverage_(character_coverage),
model_type_(model_type),
params_(params),
col_id_(0) {
sentence_queue_ = std::make_unique<Queue<TensorRow>>(op_conn_size);
read_done_ = false;
ret_status_ = Status::OK();
}
Status BuildSentencePieceVocabOp::operator()() {
RETURN_UNEXPECTED_IF_NULL(tree_);
RETURN_IF_NOT_OK(sentence_queue_->Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(
tree_->AllTasks()->CreateAsyncTask("sentenceTask", std::bind(&BuildSentencePieceVocabOp::SentenceThread, this)));
TaskManager::FindMe()->Post();
child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
TensorRow new_row;
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
bool eoe_warning = false; // give out warning if receive more than 1 eoe
while (child_iterator_->eof_handled() == false) {
while (new_row.empty() == false) {
RETURN_IF_NOT_OK(sentence_queue_->EmplaceBack(new_row));
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
}
CHECK_FAIL_RETURN_UNEXPECTED(!eoe_warning, "no op should be after from_dataset (repeat detected)");
eoe_warning = true;
}
// add empty tensorRow for quit
TensorRow empty_row = {};
sentence_queue_->EmplaceBack(empty_row);
return Status::OK();
}
Status BuildSentencePieceVocabOp::SentenceThread() {
TaskManager::FindMe()->Post();
if (col_names_.empty() == true) {
auto itr = column_name_id_map_.find("text");
CHECK_FAIL_RETURN_UNEXPECTED(itr != column_name_id_map_.end(),
"'text' column doesn't exist when column name is empty");
col_id_ = itr->second;
} else {
auto itr = column_name_id_map_.find(col_names_[0]);
CHECK_FAIL_RETURN_UNEXPECTED(itr != column_name_id_map_.end(), col_names_[0] + "column doesn't exist");
col_id_ = itr->second;
}
std::unique_ptr<DatasetSentenceIterator> sentence_iter = std::make_unique<DatasetSentenceIterator>(this);
std::string model_proto;
sentencepiece::util::Status s_status =
sentencepiece::SentencePieceTrainer::Train(BuildParams(), sentence_iter.get(), &model_proto);
if (!s_status.ok()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, s_status.message());
} else {
if (vocab_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "sentencepiece vocab ptr must not be nullptr");
}
vocab_->set_model_proto(model_proto);
}
RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
return Status::OK();
}
std::unordered_map<std::string, std::string> BuildSentencePieceVocabOp::BuildParams() {
std::unordered_map<std::string, std::string> ret_params;
ret_params["vocab_size"] = std::to_string(vocab_size_);
ret_params["character_coverage"] = std::to_string(character_coverage_);
if (model_type_ == SentencePieceModel::kBpe) {
ret_params["model_type"] = "BPE";
} else if (model_type_ == SentencePieceModel::kChar) {
ret_params["model_type"] = "CHAR";
} else if (model_type_ == SentencePieceModel::kWord) {
ret_params["model_type"] = "WORD";
} else {
ret_params["model_type"] = "UNIGRAM";
}
// filter some params that set by function param
// filter model_prefix that must be empty
for (auto param : params_) {
std::string key = param.first;
if (key == "input" || key == "vocab_size" || key == "model_prefix" || key == "character_coverage" ||
key == "model_type") {
continue;
}
ret_params[key] = param.second;
}
ret_params["model_prefix"] = "";
ret_params["minloglevel"] = "1";
return ret_params;
}
bool BuildSentencePieceVocabOp::Done() { return read_done_; }
void BuildSentencePieceVocabOp::Next(std::string *sentence) {
TensorRow new_row;
Status s = sentence_queue_->PopFront(&new_row);
if (s.IsError()) {
read_done_ = true;
ret_status_ = s;
return;
}
if (new_row.empty() == true) {
read_done_ = true;
ret_status_ = Status::OK();
return;
}
if (new_row[col_id_]->type().IsNumeric() || new_row[col_id_]->Rank() > 1) {
ret_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"for dataset only words on string columns or must bu scalar");
read_done_ = true;
return;
}
std::string_view sentence_v;
new_row[col_id_]->GetItemAt(&sentence_v, {});
std::string st{sentence_v};
*sentence = st;
ret_status_ = Status::OK();
}
// Pre-Visitor accept method for NodePass
Status BuildSentencePieceVocabOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call the pre-visitation
return p->PreRunOnNode(shared_from_base<BuildSentencePieceVocabOp>(), modified);
}
Status BuildSentencePieceVocabOp::Builder::Build(std::shared_ptr<BuildSentencePieceVocabOp> *op) {
(*op) = std::make_shared<BuildSentencePieceVocabOp>(builder_vocab_, builder_col_names_, builder_vocab_size_,
builder_character_coverage_, builder_model_type_, builder_params_,
builder_connector_size_);
return Status::OK();
}
BuildSentencePieceVocabOp::Builder::Builder() {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_connector_size_ = cfg->op_connector_size();
}
BuildSentencePieceVocabOp::DatasetSentenceIterator::DatasetSentenceIterator(BuildSentencePieceVocabOp *s_p_vocab_ptr)
: s_p_vocab_ptr_(s_p_vocab_ptr) {}
bool BuildSentencePieceVocabOp::DatasetSentenceIterator::done() const {
if (s_p_vocab_ptr_ == nullptr) {
return true;
}
return s_p_vocab_ptr_->Done();
}
void BuildSentencePieceVocabOp::DatasetSentenceIterator::Next() {
if (s_p_vocab_ptr_ == nullptr) {
return;
}
s_p_vocab_ptr_->Next(&value_);
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_BUILD_SENTENCE_VOCAB_OP_H_
#define DATASET_ENGINE_DATASETOPS_BUILD_SENTENCE_VOCAB_OP_H_
#include <sentencepiece_trainer.h>
#include <sentencepiece_processor.h>
#include <vector>
#include <memory>
#include <unordered_map>
#include <string>
#include <utility>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/dataset_iterator.h"
#include "minddata/dataset/engine/datasetops/pipeline_op.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/text/sentence_piece_vocab.h"
namespace mindspore {
namespace dataset {
namespace py = pybind11;
class BuildSentencePieceVocabOp : public PipelineOp {
public:
class Builder {
public:
Builder();
// Destructor.
~Builder() = default;
// Setter method
// @param uint32_t size
// @return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(uint32_t size) {
builder_connector_size_ = size;
return *this;
}
// Setter method
// @param uint32_t size
// @return Builder & reference to builder class object
Builder &SetVocabSize(uint32_t size) {
builder_vocab_size_ = size;
return *this;
}
// Setter method
// @param float charactor corverage - to determine the minimum symbols
// @return Builder & reference to builder class object
Builder &SetCharacterCoverage(float character_coverage) {
builder_character_coverage_ = character_coverage;
return *this;
}
// Setter method
// @param SentencePieceModel model_type - model algorithm
// @return Builder & reference to builder class object
Builder &SetModelType(SentencePieceModel model_type) {
builder_model_type_ = model_type;
return *this;
}
// Setter method
// @param std::unordered_map<std::string, std::string> params
// @return Builder & reference to builder class object
Builder &SetParams(std::unordered_map<std::string, std::string> params) {
builder_params_ = params;
return *this;
}
// Setter method
// @param std::shared_ptr<SentencePieceVocab> vocab
// @return Builder & reference to builder class object
Builder &SetVocab(std::shared_ptr<SentencePieceVocab> vocab) {
builder_vocab_ = vocab;
return *this;
}
// set columns names
// @param const std::vector<std::string> & col_names - name of columns to get words
// @return Builder & reference to builder class object
Builder &SetColumnNames(const std::vector<std::string> &col_names) {
builder_col_names_ = col_names;
return *this;
}
// The builder "build" method creates the final object.
// @param std::shared_ptr<BuildVocabOp> *op - DatasetOp
// @return - The error code return
Status Build(std::shared_ptr<BuildSentencePieceVocabOp> *op);
private:
uint32_t builder_connector_size_;
uint32_t builder_vocab_size_;
float builder_character_coverage_;
SentencePieceModel builder_model_type_;
std::unordered_map<std::string, std::string> builder_params_;
std::vector<std::string> builder_col_names_;
std::shared_ptr<SentencePieceVocab> builder_vocab_;
};
public:
class DatasetSentenceIterator : public sentencepiece::SentenceIterator {
public:
explicit DatasetSentenceIterator(BuildSentencePieceVocabOp *s_p_vocab_ptr);
~DatasetSentenceIterator() {}
bool done() const override;
void Next() override;
const std::string &value() const override { return value_; }
sentencepiece::util::Status status() const override { return sentencepiece::util::Status(); }
private:
std::string value_;
BuildSentencePieceVocabOp *s_p_vocab_ptr_;
};
BuildSentencePieceVocabOp(std::shared_ptr<SentencePieceVocab> vocab, std::vector<std::string> col_names,
uint32_t vocab_size, float character_coverage, SentencePieceModel model_type,
const std::unordered_map<std::string, std::string> &params, int32_t op_conn_size);
~BuildSentencePieceVocabOp() = default;
// the thread for sentence train
Status SentenceThread();
Status EofReceived(int32_t) override { return Status::OK(); }
Status EoeReceived(int32_t) override { return Status::OK(); }
Status operator()() override;
// Getter
// @return the number of workers
int32_t num_producers() const override { return 1; }
// Getter
// @return the number of threads consuming from the previous Connector
int32_t num_consumers() const override { return 1; }
Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildSentencePieceVocabOp"); }
// build the input params for sentence api
std::unordered_map<std::string, std::string> BuildParams();
bool Done();
void Next(std::string *sentence);
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
private:
bool read_done_;
Status ret_status_;
uint32_t vocab_size_;
float character_coverage_;
SentencePieceModel model_type_;
std::unordered_map<std::string, std::string> params_;
std::shared_ptr<SentencePieceVocab> vocab_;
std::vector<std::string> col_names_;
uint32_t col_id_;
std::unique_ptr<ChildIterator> child_iterator_; // child iterator for fetching TensorRows 1 by 1
std::unique_ptr<Queue<TensorRow>> sentence_queue_; // master thread assigns each worker TensorRow via this
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_BUILD_SENTENCE_VOCAB_OP_H_
......@@ -17,6 +17,7 @@
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/datasetops/batch_op.h"
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
#include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
......@@ -261,5 +262,10 @@ Status NodePass::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -81,6 +81,8 @@ class EpochCtrlOp;
class BuildVocabOp;
class BuildSentencePieceVocabOp;
// The base class Pass is the basic unit of tree transformation.
// The actual implementation of the passes will be derived from here.
class Pass : public std::enable_shared_from_this<Pass> {
......@@ -206,6 +208,8 @@ class NodePass : public Pass {
virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified);
private:
// Helper function to perform DFS visit
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);
......
......@@ -37,6 +37,16 @@ Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp
}
}
// Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) {
if (injection_pass_) {
injection_pass_->epoch_ctrl_bypass_ = true;
return Status::OK();
} else {
RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!");
}
}
// Temporary code to prevent the injection of epoch control when cache op is present
// Remove this code in cache op phase 2
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
......
......@@ -45,6 +45,12 @@ class InjectionPass : public TreePass {
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) override;
/// \brief Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) override;
/// \brief Temporary code to prevent the injection of epoch control when cache op is present.
/// Remove this code in cache op phase 2
/// \param[in] node The node being visited
......
......@@ -136,6 +136,7 @@ constexpr char kRandomChoiceOp[] = "RandomChoiceOp";
constexpr char kRandomApplyOp[] = "RandomApplyOp";
constexpr char kComposeOp[] = "ComposeOp";
constexpr char kRandomSelectSubpolicyOp[] = "RandomSelectSubpolicyOp";
constexpr char kSentencepieceTokenizerOp[] = "SentencepieceTokenizerOp";
// data
constexpr char kConcatenateOp[] = "kConcatenateOp";
......
......@@ -4,6 +4,7 @@ file(GLOB _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(text OBJECT
vocab.cc
sentence_piece_vocab.cc
)
add_dependencies(text text-kernels)
\ No newline at end of file
......@@ -21,5 +21,6 @@ add_library(text-kernels OBJECT
wordpiece_tokenizer_op.cc
truncate_sequence_pair_op.cc
to_number_op.cc
sentence_piece_tokenizer_op.cc
${ICU_DEPEND_FILES}
)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h"
#include <memory>
#include <vector>
#include "common/utils.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::shared_ptr<SentencePieceVocab> vocab,
const SPieceTokenizerLoadType load_type,
const SPieceTokenizerOutType out_type)
: vocab_(vocab), load_type_(load_type), out_type_(out_type) {}
SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::string &model_path, const std::string &model_filename,
const SPieceTokenizerLoadType load_type,
const SPieceTokenizerOutType out_type)
: load_type_(load_type), out_type_(out_type) {
(void)GetModelRealPath(model_path, model_filename);
}
Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
if (input->Rank() != 0 || input->type() != DataType::DE_STRING) {
RETURN_STATUS_UNEXPECTED("the input tensor should be scalar string tensor");
}
std::string_view sentence_v;
RETURN_IF_NOT_OK(input->GetItemAt(&sentence_v, {}));
std::string sentence{sentence_v};
if (load_type_ == SPieceTokenizerLoadType::kFile) {
auto status = processor_.Load(file_path_);
if (!status.ok()) {
RETURN_STATUS_UNEXPECTED("load sentence piece model failed.");
}
} else {
RETURN_UNEXPECTED_IF_NULL(vocab_);
auto status = processor_.LoadFromSerializedProto(vocab_.get()->model_proto());
if (!status.ok()) {
RETURN_STATUS_UNEXPECTED("sentence piece load model failed.");
}
}
if (out_type_ == SPieceTokenizerOutType::kString) {
std::vector<std::string> pieces;
auto status = processor_.Encode(sentence, &pieces);
if (!status.ok()) {
RETURN_STATUS_UNEXPECTED("sentence piece tokenizer error");
}
*output = std::make_unique<Tensor>(pieces, TensorShape({(dsize_t)pieces.size()}));
} else {
std::vector<int> ids;
auto status = processor_.Encode(sentence, &ids);
if (!status.ok()) {
RETURN_STATUS_UNEXPECTED("sentence piece tokenizer error");
}
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, ids, TensorShape({(dsize_t)ids.size()})));
}
return Status::OK();
}
Status SentencePieceTokenizerOp::GetModelRealPath(const std::string &model_path, const std::string &filename) {
char real_path[PATH_MAX] = {0};
if (file_path_.size() >= PATH_MAX) {
RETURN_STATUS_UNEXPECTED("sentence piece model path is invalid.");
}
#if defined(_WIN32) || defined(_WIN64)
if (_fullpath(real_path, common::SafeCStr(model_path), PATH_MAX) == nullptr) {
RETURN_STATUS_UNEXPECTED("sentence piece model path is invalid.");
}
#else
if (realpath(common::SafeCStr(model_path), real_path) == nullptr) {
RETURN_STATUS_UNEXPECTED("sentence piece model path is invalid.");
}
#endif
std::string abs_path = real_path;
file_path_ = (Path(abs_path) / Path(filename)).toString();
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_SENTENCE_PIECE_TOKENIZER_OP_H
#define DATASET_SENTENCE_PIECE_TOKENIZER_OP_H
#include <sentencepiece_processor.h>
#include <string>
#include <iostream>
#include <memory>
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/text/sentence_piece_vocab.h"
namespace mindspore {
namespace dataset {
enum class SPieceTokenizerOutType { kString = 0, kInt = 1 };
enum class SPieceTokenizerLoadType { kFile = 0, kModel = 1 };
class SentencePieceTokenizerOp : public TensorOp {
public:
SentencePieceTokenizerOp(const std::shared_ptr<SentencePieceVocab> vocab, SPieceTokenizerLoadType load_type,
const SPieceTokenizerOutType out_type);
SentencePieceTokenizerOp(const std::string &model_path, const std::string &model_filename,
const SPieceTokenizerLoadType load_type, const SPieceTokenizerOutType out_type);
~SentencePieceTokenizerOp() override = default;
Status GetModelRealPath(const std::string &model_path, const std::string &filename);
void Print(std::ostream &out) const override {
out << "SentencePieceTokenizerOp out_type = " << out_type_ << " load_type = " << load_type_;
}
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kSentencepieceTokenizerOp; }
protected:
SPieceTokenizerOutType out_type_;
std::shared_ptr<SentencePieceVocab> vocab_;
std::string file_path_;
SPieceTokenizerLoadType load_type_;
sentencepiece::SentencePieceProcessor processor_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_SENTENCE_PIECE_TOKENIZER_OP_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/text/sentence_piece_vocab.h"
#include <sentencepiece_trainer.h>
#include <sentencepiece_processor.h>
#include <fstream>
#include "common/utils.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
SentencePieceVocab::SentencePieceVocab() : model_proto_("") {}
Status SentencePieceVocab::BuildFromFile(const std::vector<std::string> &path_list, const int vocab_size,
const float character_coverage, const SentencePieceModel model_type,
const std::unordered_map<std::string, std::string> &params,
std::shared_ptr<SentencePieceVocab> *vocab) {
std::unordered_map<std::string, std::string> unorder_map;
// the input of sentence is comma separated string
std::string input_str = "";
for (auto path : path_list) {
input_str += path;
input_str += ",";
}
input_str.pop_back();
unorder_map["input"] = input_str;
unorder_map["vocab_size"] = std::to_string(vocab_size);
unorder_map["model_prefix"] = "";
unorder_map["minloglevel"] = "1";
unorder_map["character_coverage"] = std::to_string(character_coverage);
if (model_type == SentencePieceModel::kWord) {
unorder_map["model_type"] = "WORD";
} else if (model_type == SentencePieceModel::kBpe) {
unorder_map["model_type"] = "BPE";
} else if (model_type == SentencePieceModel::kChar) {
unorder_map["model_type"] = "CHAR";
} else {
unorder_map["model_type"] = "UNIGRAM";
}
// filter some params that set by function param
// filter model_prefix that must be empty
for (auto param : params) {
std::string key = param.first;
if (key == "input" || key == "vocab_size" || key == "model_prefix" || key == "character_coverage" ||
key == "model_type") {
continue;
}
unorder_map[key] = param.second;
}
// set sentence lib's log
unorder_map["minloglevel"] = "1";
*vocab = std::make_shared<SentencePieceVocab>();
std::string model_proto;
sentencepiece::util::Status s_status = sentencepiece::SentencePieceTrainer::Train(unorder_map, nullptr, &model_proto);
if (!s_status.ok()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, s_status.message());
}
vocab->get()->set_model_proto(model_proto);
return Status::OK();
}
Status SentencePieceVocab::SaveModel(const std::shared_ptr<SentencePieceVocab> *vocab, std::string path,
std::string filename) {
char real_path[PATH_MAX] = {0};
if (path.size() >= PATH_MAX) {
RETURN_STATUS_UNEXPECTED("sentence model path is invalid.");
}
#if defined(_WIN32) || defined(_WIN64)
if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) {
RETURN_STATUS_UNEXPECTED("sentence model path is invalid.");
}
#else
if (realpath(common::SafeCStr(path), real_path) == nullptr) {
RETURN_STATUS_UNEXPECTED("sentence model path is invalid.");
}
#endif
std::string abs_real_path = (Path(real_path) / Path(filename)).toString();
std::ofstream os_file(abs_real_path, std::ios::out);
(void)os_file.write(vocab->get()->model_proto().data(), vocab->get()->model_proto().size());
os_file.close();
return Status::OK();
}
const std::string &SentencePieceVocab::model_proto() { return model_proto_; }
void SentencePieceVocab::set_model_proto(const std::string model_proto) { model_proto_ = model_proto; }
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_TEXT_SENTENCE_PIECE_VOCAB_H_
#define DATASET_TEXT_SENTENCE_PIECE_VOCAB_H_
#include <string>
#include <memory>
#include <vector>
#include <unordered_map>
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
enum class SentencePieceModel { kUnigram = 0, kBpe = 1, kChar = 2, kWord = 3 };
class SentencePieceVocab {
public:
static Status BuildFromFile(const std::vector<std::string> &path_list, const int vocab_size,
const float character_coverage, const SentencePieceModel model_type,
const std::unordered_map<std::string, std::string> &params,
std::shared_ptr<SentencePieceVocab> *vocab);
static Status SaveModel(const std::shared_ptr<SentencePieceVocab> *vocab, std::string path, std::string filename);
SentencePieceVocab();
~SentencePieceVocab() = default;
const std::string &model_proto();
void set_model_proto(const std::string model_proto);
private:
std::string model_proto_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_TEXT_SENTENCE_PIECE_VOCAB_H_
......@@ -46,6 +46,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE
try:
context = import_module("mindspore.context")
......@@ -909,6 +910,11 @@ class Dataset:
def build_vocab(self, vocab, columns, freq_range, top_k, special_tokens, special_first):
return BuildVocabDataset(self, vocab, columns, freq_range, top_k, special_tokens, special_first)
def build_sentencepiece_vocab(self, vocab, col_names, vocab_size,
character_coverage, model_type, params):
return BuildSentencePieceVocabDataset(self, vocab, col_names, vocab_size, character_coverage,
model_type, params)
def apply(self, apply_func):
"""
Apply a function in this dataset.
......@@ -5154,3 +5160,58 @@ class BuildVocabDataset(DatasetOp):
new_op.special_first = copy.deepcopy(self.special_first)
return new_op
class BuildSentencePieceVocabDataset(DatasetOp):
"""
Build a SentencePieceVocab from a dataset.
This function is not meant to be called directly by user. To build vocab, please use the function
text.SentencePieceVocab.from_dataset()
Args:
vocab(SentencePieceVocab): text.SentencePieceVocab object.
col_names(list): The list of the col name.
vocab_size(int): Vocabulary size, the type of uint32_t.
charater_coverage(float): Amount of characters covered by the model, good defaults are: 0.9995 for languages
with rich character set like Japanse or Chinese and 1.0 for other languages with small character set.
model_type(SentencePieceModel): Model type.Choose from unigram (default), bpe, char, or word.
The input sentence must be pretokenized when using word type.
params(dict): A dictionary with no incoming parameters.
"""
def __init__(self, input_dataset, vocab, col_names, vocab_size, character_coverage, model_type, params):
super().__init__()
self.vocab = vocab
self.col_names = col_names
self.vocab_size = vocab_size
self.children.append(input_dataset)
self.character_coverage = character_coverage
self.model_type = DE_C_INTER_SENTENCEPIECE_MODE[model_type]
self.params = params
input_dataset.parent.append(self)
def get_args(self):
args = super().get_args()
args["vocab"] = self.vocab
args["col_names"] = self.col_names
args["vocab_size"] = self.vocab_size
args["character_coverage"] = self.character_coverage
args["model_type"] = self.model_type
args["params"] = self.params
return args
def __deepcopy__(self, memodict):
if id(self) in memodict:
return memodict[id(self)]
cls = self.__class__
new_op = cls.__new__(cls)
memodict[id(self)] = new_op
new_op.children = copy.deepcopy(self.children, memodict)
new_op.col_names = copy.deepcopy(self.col_names, memodict)
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
new_op.vocab_size = copy.deepcopy(self.vocab_size, memodict)
new_op.parent = copy.deepcopy(self.parent, memodict)
new_op.character_coverage = copy.deepcopy(self.character_coverage, memodict)
new_op.params = copy.deepcopy(self.params, memodict)
new_op.vocab = self.vocab
new_op.model_type = copy.deepcopy(self.model_type)
return new_op
......@@ -181,6 +181,8 @@ class Iterator:
op_type = OpName.TEXTFILE
elif isinstance(dataset, de.BuildVocabDataset):
op_type = OpName.BUILDVOCAB
elif isinstance(dataset, de.BuildSentencePieceVocabDataset):
op_type = OpName.SENTENCEPIECEVOCAB
elif isinstance(dataset, de.CLUEDataset):
op_type = OpName.CLUE
else:
......
......@@ -19,13 +19,16 @@ utils provides some general methods for nlp text processing.
"""
import platform
from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair, \
ToNumber, SlidingWindow
from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm
ToNumber, SlidingWindow, SentencePieceTokenizer
from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm, SentencePieceVocab, SentencePieceModel, \
SPieceTokenizerOutType, SPieceTokenizerLoadType
__all__ = [
"Lookup", "JiebaTokenizer", "UnicodeCharTokenizer", "Ngram",
"to_str", "to_bytes", "Vocab", "WordpieceTokenizer", "TruncateSequencePair", "ToNumber",
"PythonTokenizer", "SlidingWindow"
"PythonTokenizer", "SlidingWindow", "SentencePieceVocab", "SentencePieceTokenizer", "SPieceTokenizerOutType",
"SentencePieceModel", "SPieceTokenizerLoadType"
]
if platform.system().lower() != 'windows':
......
......@@ -50,7 +50,7 @@ import numpy as np
import mindspore._c_dataengine as cde
from .utils import JiebaMode, NormalizeForm, to_str
from .utils import JiebaMode, NormalizeForm, to_str, SPieceTokenizerOutType, SPieceTokenizerLoadType
from .validators import check_lookup, check_jieba_add_dict, \
check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer,\
check_wordpiece_tokenizer, check_regex_tokenizer, check_basic_tokenizer, check_ngram, check_pair_truncate,\
......@@ -324,6 +324,36 @@ class WordpieceTokenizer(cde.WordpieceTokenizerOp):
super().__init__(self.vocab, self.suffix_indicator, self.max_bytes_per_token,
self.unknown_token, self.with_offsets)
DE_C_INTER_SENTENCEPIECE_LOADTYPE = {
SPieceTokenizerLoadType.FILE: cde.SPieceTokenizerLoadType.DE_SPIECE_TOKENIZER_LOAD_KFILE,
SPieceTokenizerLoadType.MODEL: cde.SPieceTokenizerLoadType.DE_SPIECE_TOKENIZER_LOAD_KMODEL
}
DE_C_INTER_SENTENCEPIECE_OUTTYPE = {
SPieceTokenizerOutType.STRING: cde.SPieceTokenizerOutType.DE_SPIECE_TOKENIZER_OUTTYPE_KString,
SPieceTokenizerOutType.INT: cde.SPieceTokenizerOutType.DE_SPIECE_TOKENIZER_OUTTYPE_KINT
}
class SentencePieceTokenizer(cde.SentencePieceTokenizerOp):
"""
Tokenize scalar token or 1-D tokens to tokens by sentencepiece.
Args:
mode(str or SentencePieceVocab): If the input parameter is a file, then it is of type string,
if the input parameter is a SentencePieceVocab object, then it is of type SentencePieceVocab.
out_type(str or int): The type of output.
"""
def __init__(self, mode, out_type):
self.out_type = out_type
if isinstance(mode, str):
model_path, model_filename = os.path.split(mode)
super().__init__(model_path, model_filename,
DE_C_INTER_SENTENCEPIECE_LOADTYPE[SPieceTokenizerLoadType.FILE],
DE_C_INTER_SENTENCEPIECE_OUTTYPE[out_type])
elif isinstance(mode, cde.SentencePieceVocab):
super().__init__(mode, DE_C_INTER_SENTENCEPIECE_LOADTYPE[SPieceTokenizerLoadType.MODEL],
DE_C_INTER_SENTENCEPIECE_OUTTYPE[out_type])
if platform.system().lower() != 'windows':
class WhitespaceTokenizer(cde.WhitespaceTokenizerOp):
......
......@@ -22,10 +22,11 @@ import copy
import numpy as np
import mindspore._c_dataengine as cde
from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset
from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset, \
check_from_dataset_sentencepiece, check_from_file_sentencepiece, check_save_model
__all__ = [
"Vocab", "to_str", "to_bytes"
"Vocab", "SentencePieceVocab", "to_str", "to_bytes"
]
......@@ -137,6 +138,71 @@ class Vocab(cde.Vocab):
return super().from_dict(word_dict)
class SentencePieceVocab(cde.SentencePieceVocab):
"""
SentencePiece obiect that is used to segmentate words
"""
@classmethod
@check_from_dataset_sentencepiece
def from_dataset(cls, dataset, col_names, vocab_size, character_coverage, model_type, params):
"""
Build a sentencepiece from a dataset
Args:
dataset(Dataset): Dataset to build sentencepiece.
col_names(list): The list of the col name.
vocab_size(int): Vocabulary size, the type of uint32_t.
character_coverage(float): Amount of characters covered by the model, good defaults are: 0.9995 for
languages. with rich character set like Japanse or Chinese and 1.0 for other languages with small
character set.
model_type(SentencePieceModel): Choose from unigram (default), bpe, char, or word. The input sentence
must be pretokenized when using word type.
params(dict): A dictionary with no incoming parameters.
Returns:
SentencePiece, SentencePiece object from dataset.
"""
vocab = SentencePieceVocab()
root = copy.deepcopy(dataset).build_sentencepiece_vocab(vocab, col_names, vocab_size, character_coverage,
model_type, params)
for d in root.create_dict_iterator():
if d is None:
raise ValueError("from_dataset should receive data other than None.")
return vocab
@classmethod
@check_from_file_sentencepiece
def from_file(cls, file_path, vocab_size, character_coverage, model_type, params):
"""
Build a SentencePiece object from a list of word.
Args:
file_path(list): Path to the file which contains the sentencepiece list.
vocab_size(int): Vocabulary size, the type of uint32_t.
character_coverage(float): Amount of characters covered by the model, good defaults are: 0.9995 for
languages. with rich character set like Japanse or Chinese and 1.0 for other languages with small
character set.
model_type(SentencePieceModel): Choose from unigram (default), bpe, char, or word. The input sentence
must be pretokenized when using word type.
params(dict): A dictionary with no incoming parameters.
"""
return super().from_file(file_path, vocab_size, character_coverage,
DE_C_INTER_SENTENCEPIECE_MODE[model_type], params)
@classmethod
@check_save_model
def save_model(cls, vocab, path, filename):
"""
Save model to filepath
Args:
vocab(SentencePieceVocab): A sentencepiece object.
path(str): Path to store model.
filename(str): The name of the file.
"""
return super().save_model(vocab, path, filename)
def to_str(array, encoding='utf8'):
"""
......@@ -188,3 +254,27 @@ class NormalizeForm(IntEnum):
NFKC = 2
NFD = 3
NFKD = 4
class SentencePieceModel(IntEnum):
"""An enumeration for SentencePieceModel, effective enumeration types are UNIGRAM, BPE, CHAR, WORD."""
UNIGRAM = 0
BPE = 1
CHAR = 2
WORD = 3
DE_C_INTER_SENTENCEPIECE_MODE = {
SentencePieceModel.UNIGRAM: cde.SentencePieceModel.DE_SENTENCE_PIECE_UNIGRAM,
SentencePieceModel.BPE: cde.SentencePieceModel.DE_SENTENCE_PIECE_BPE,
SentencePieceModel.CHAR: cde.SentencePieceModel.DE_SENTENCE_PIECE_CHAR,
SentencePieceModel.WORD: cde.SentencePieceModel.DE_SENTENCE_PIECE_WORD
}
class SPieceTokenizerOutType(IntEnum):
"""An enumeration for SPieceTokenizerOutType, effective enumeration types are STRING, INT."""
STRING = 0
INT = 1
class SPieceTokenizerLoadType(IntEnum):
"""An enumeration for SPieceTokenizerLoadType, effective enumeration types are FILE, MODEL."""
FILE = 0
MODEL = 1
......@@ -419,3 +419,81 @@ def check_python_tokenizer(method):
return method(self, *args, **kwargs)
return new_method
def check_from_dataset_sentencepiece(method):
"""A wrapper that wraps a parameter checker to the original function (from_dataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
[_, col_names, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs)
if col_names is not None:
type_check(col_names, (list,), "col_names")
if vocab_size is not None:
check_uint32(vocab_size, "vocab_size")
if character_coverage is not None:
type_check(character_coverage, (float,), "character_coverage")
if model_type is not None:
from .utils import SentencePieceModel
type_check(model_type, (str, SentencePieceModel), "model_type")
if params is not None:
type_check(params, (dict,), "params")
return method(self, *args, **kwargs)
return new_method
def check_from_file_sentencepiece(method):
"""A wrapper that wraps a parameter checker to the original function (from_file)."""
@wraps(method)
def new_method(self, *args, **kwargs):
[file_path, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs)
if file_path is not None:
type_check(file_path, (list,), "file_path")
if vocab_size is not None:
check_uint32(vocab_size, "vocab_size")
if character_coverage is not None:
type_check(character_coverage, (float,), "character_coverage")
if model_type is not None:
from .utils import SentencePieceModel
type_check(model_type, (str, SentencePieceModel), "model_type")
if params is not None:
type_check(params, (dict,), "params")
return method(self, *args, **kwargs)
return new_method
def check_save_model(method):
"""A wrapper that wraps a parameter checker to the original function (save_model)."""
@wraps(method)
def new_method(self, *args, **kwargs):
[vocab, path, filename], _ = parse_user_args(method, *args, **kwargs)
if vocab is not None:
type_check(vocab, (cde.SentencePieceVocab,), "vocab")
if path is not None:
type_check(path, (str,), "path")
if filename is not None:
type_check(filename, (str,), "filename")
return method(self, *args, **kwargs)
return new_method
\ No newline at end of file
......@@ -94,6 +94,7 @@ SET(DE_UT_SRCS
tensor_op_fusion_pass_test.cc
sliding_window_op_test.cc
epoch_ctrl_op_test.cc
sentence_piece_vocab_op_test.cc
)
add_executable(de_ut_tests ${DE_UT_SRCS})
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <string_view>
#include "common/common.h"
#include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h"
#include "minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h"
#include "minddata/dataset/text/sentence_piece_vocab.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include "minddata/dataset/util/status.h"
using namespace mindspore::dataset;
class MindDataTestSentencePieceVocabOp : public UT::DatasetOpTesting {
public:
void CheckEqual(const std::shared_ptr<Tensor> &o, const std::vector<dsize_t> &index, const std::string &expect) {
std::string_view str;
Status s = o->GetItemAt(&str, index);
EXPECT_TRUE(s.IsOk());
EXPECT_EQ(str, expect);
}
};
TEST_F(MindDataTestSentencePieceVocabOp, TestSentencePieceFromDatasetFuntions) {
MS_LOG(INFO) << "Doing MindDataTestSentencePieceVocabOp TestSentencePieceFromDatasetFuntions.";
std::string dataset_path;
dataset_path = datasets_root_path_ + "/test_sentencepiece/botchan.txt";
auto tree = std::make_shared<ExecutionTree>();
std::shared_ptr<TextFileOp> file_op;
TextFileOp::Builder builder_file;
builder_file.SetTextFilesList({dataset_path}).SetRowsPerBuffer(1).SetNumWorkers(1).SetOpConnectorSize(2);
Status rc = builder_file.Build(&file_op);
ASSERT_TRUE(rc.IsOk());
rc = tree->AssociateNode(file_op);
ASSERT_TRUE(rc.IsOk());
std::shared_ptr<SentencePieceVocab> spm = std::make_unique<SentencePieceVocab>();
std::shared_ptr<BuildSentencePieceVocabOp> spv_op;
BuildSentencePieceVocabOp::Builder builder_spv;
std::vector<std::string> cols;
std::unordered_map<std::string, std::string> m_params;
builder_spv.SetVocab(spm)
.SetVocabSize(5000)
.SetColumnNames(cols)
.SetCharacterCoverage(0.9995)
.SetModelType(SentencePieceModel::kUnigram)
.SetParams(m_params)
.SetOpConnectorSize(2);
rc = builder_spv.Build(&spv_op);
ASSERT_TRUE(rc.IsOk());
rc = tree->AssociateNode(spv_op);
ASSERT_TRUE(rc.IsOk());
rc = spv_op->AddChild(file_op);
ASSERT_TRUE(rc.IsOk());
rc = tree->AssignRoot(spv_op);
ASSERT_TRUE(rc.IsOk());
rc = tree->Prepare();
ASSERT_TRUE(rc.IsOk());
rc = tree->Launch();
ASSERT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(tree);
TensorRow tensor_list;
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
while (!tensor_list.empty()) {
rc = di.FetchNextTensorRow(&tensor_list);
}
ASSERT_TRUE(rc.IsOk());
}
TEST_F(MindDataTestSentencePieceVocabOp, TestSentencePieceFromFileFuntions) {
MS_LOG(INFO) << "Doing MindDataTestSentencePieceVocabOp TestSentencePieceFromFileFuntions.";
std::string dataset_path;
dataset_path = datasets_root_path_ + "/test_sentencepiece/botchan.txt";
std::vector<std::string> path_list;
path_list.emplace_back(dataset_path);
std::unordered_map<std::string, std::string> param_map;
std::shared_ptr<SentencePieceVocab> spm = std::make_unique<SentencePieceVocab>();
Status rc = SentencePieceVocab::BuildFromFile(path_list, 5000, 0.9995, SentencePieceModel::kUnigram, param_map, &spm);
ASSERT_TRUE(rc.IsOk());
}
TEST_F(MindDataTestSentencePieceVocabOp, TestSentencePieceTokenizerFuntions) {
MS_LOG(INFO) << "Doing MindDataTestSentencePieceVocabOp TestSentencePieceTokenizerFuntions.";
std::string dataset_path;
dataset_path = datasets_root_path_ + "/test_sentencepiece/botchan.txt";
auto tree = std::make_shared<ExecutionTree>();
std::shared_ptr<TextFileOp> file_op;
TextFileOp::Builder builder_file;
builder_file.SetTextFilesList({dataset_path}).SetRowsPerBuffer(1).SetNumWorkers(1).SetOpConnectorSize(2);
Status rc = builder_file.Build(&file_op);
ASSERT_TRUE(rc.IsOk());
rc = tree->AssociateNode(file_op);
ASSERT_TRUE(rc.IsOk());
std::shared_ptr<SentencePieceVocab> spm = std::make_unique<SentencePieceVocab>();
std::shared_ptr<BuildSentencePieceVocabOp> spv_op;
BuildSentencePieceVocabOp::Builder builder_spv;
std::vector<std::string> cols;
std::unordered_map<std::string, std::string> m_params;
builder_spv.SetVocab(spm)
.SetVocabSize(5000)
.SetColumnNames(cols)
.SetCharacterCoverage(0.9995)
.SetModelType(SentencePieceModel::kUnigram)
.SetParams(m_params)
.SetOpConnectorSize(2);
rc = builder_spv.Build(&spv_op);
ASSERT_TRUE(rc.IsOk());
rc = tree->AssociateNode(spv_op);
ASSERT_TRUE(rc.IsOk());
rc = spv_op->AddChild(file_op);
ASSERT_TRUE(rc.IsOk());
rc = tree->AssignRoot(spv_op);
ASSERT_TRUE(rc.IsOk());
rc = tree->Prepare();
ASSERT_TRUE(rc.IsOk());
rc = tree->Launch();
ASSERT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(tree);
TensorRow tensor_list;
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
while (!tensor_list.empty()) {
rc = di.FetchNextTensorRow(&tensor_list);
}
std::shared_ptr<Tensor> output_tensor;
std::unique_ptr<SentencePieceTokenizerOp> op(new SentencePieceTokenizerOp(spm,
SPieceTokenizerLoadType::kModel, SPieceTokenizerOutType::kString));
std::shared_ptr<Tensor> input_tensor = std::make_shared<Tensor>("I saw a girl with a telescope.");
Status s = op->Compute(input_tensor, &output_tensor);
std::vector<std::string> expect;
expect.push_back("▁I");
expect.push_back("▁sa");
expect.push_back("w");
expect.push_back("▁a");
expect.push_back("▁girl");
expect.push_back("▁with");
expect.push_back("▁a");
expect.push_back("▁te");
expect.push_back("les");
expect.push_back("co");
expect.push_back("pe");
expect.push_back(".");
ASSERT_TRUE(output_tensor->Size() == expect.size());
for (int i = 0; i < output_tensor->Size(); i++) {
std::string_view str;
output_tensor->GetItemAt(&str, {i});
std::string sentence{str};
ASSERT_TRUE(sentence == expect[i]);
}
}
\ No newline at end of file
I saw a girl with a telescope.
\ No newline at end of file
因为 它太大了无法显示 source diff 。你可以改为 查看blob
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset.text as text
import mindspore.dataset as ds
from mindspore.dataset.text import SentencePieceModel, to_str, SPieceTokenizerOutType
VOCAB_FILE = "../data/dataset/test_sentencepiece/botchan.txt"
DATA_FILE = "../data/dataset/testTokenizerData/sentencepiece_tokenizer.txt"
def test_from_vocab_to_str():
vocab = text.SentencePieceVocab.from_file([VOCAB_FILE], 5000, 0.9995, SentencePieceModel.UNIGRAM, {})
tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING)
dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
dataset = dataset.map(operations=tokenizer)
expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.']
for i in dataset.create_dict_iterator():
ret = to_str(i["text"])
for key, value in enumerate(ret):
assert value == expect[key]
def test_from_vocab_to_int():
vocab = text.SentencePieceVocab.from_file([VOCAB_FILE], 5000, 0.9995, SentencePieceModel.UNIGRAM, {})
tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.INT)
dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
dataset = dataset.map(operations=tokenizer)
expect = [6, 329, 183, 8, 945, 23, 8, 3783, 4382, 4641, 1405, 4]
for i in dataset.create_dict_iterator():
ret = i["text"]
for key, value in enumerate(ret):
assert value == expect[key]
def test_from_file_to_str():
vocab = text.SentencePieceVocab.from_file([VOCAB_FILE], 5000, 0.9995, SentencePieceModel.UNIGRAM, {})
text.SentencePieceVocab.save_model(vocab, "./", "m.model")
tokenizer = text.SentencePieceTokenizer("./m.model", out_type=SPieceTokenizerOutType.STRING)
dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
dataset = dataset.map(operations=tokenizer)
expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.']
for i in dataset.create_dict_iterator():
ret = to_str(i["text"])
for key, value in enumerate(ret):
assert value == expect[key]
def test_from_file_to_int():
vocab = text.SentencePieceVocab.from_file([VOCAB_FILE], 5000, 0.9995, SentencePieceModel.UNIGRAM, {})
text.SentencePieceVocab.save_model(vocab, "./", "m.model")
tokenizer = text.SentencePieceTokenizer("./m.model", out_type=SPieceTokenizerOutType.INT)
dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
dataset = dataset.map(operations=tokenizer)
expect = [6, 329, 183, 8, 945, 23, 8, 3783, 4382, 4641, 1405, 4]
for i in dataset.create_dict_iterator():
ret = i["text"]
for key, value in enumerate(ret):
assert value == expect[key]
def test_build_from_dataset():
data = ds.TextFileDataset(VOCAB_FILE, shuffle=False)
vocab = text.SentencePieceVocab.from_dataset(data, [""], 5000, 0.9995, SentencePieceModel.UNIGRAM, {})
tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING)
dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
dataset = dataset.map(operations=tokenizer)
expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.']
for i in dataset.create_dict_iterator():
ret = to_str(i["text"])
for key, value in enumerate(ret):
assert value == expect[key]
if __name__ == "__main__":
test_from_vocab_to_str()
test_from_vocab_to_int()
test_from_file_to_str()
test_from_file_to_int()
test_build_from_dataset()
diff -Npur sentencepiece-0.1.92/src/CMakeLists.txt sentencepiece-0.1.92_bak/src/CMakeLists.txt
--- sentencepiece-0.1.92/src/CMakeLists.txt 2020-06-08 16:25:01.000000000 +0800
+++ sentencepiece-0.1.92_bak/src/CMakeLists.txt 2020-07-02 17:42:33.306933546 +0800
@@ -11,6 +11,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.!
+add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
+
+
+function(protobuf_generate c_var h_var)
+ if(NOT ARGN)
+ message(SEND_ERROR "Error: ms_protobuf_generate() called without any proto files")
+ return()
+ endif()
+
+ set(${c_var})
+ set(${h_var})
+
+ find_program(PROTOC_EXE NAMES "protoc" PATHS ${PROTOBUF_INC}/../bin NO_DEFAULT_PATH)
+
+ foreach(file ${ARGN})
+ get_filename_component(abs_file ${file} ABSOLUTE)
+ get_filename_component(file_name ${file} NAME_WE)
+ get_filename_component(file_dir ${abs_file} PATH)
+ file(RELATIVE_PATH rel_path ${CMAKE_CURRENT_SOURCE_DIR} ${file_dir})
+
+ list(APPEND ${c_var} "${CMAKE_BINARY_DIR}/${file_name}.pb.cc")
+ list(APPEND ${h_var} "${CMAKE_BINARY_DIR}/${file_name}.pb.h")
+
+ add_custom_command(
+ OUTPUT "${CMAKE_BINARY_DIR}/${file_name}.pb.cc"
+ "${CMAKE_BINARY_DIR}/${file_name}.pb.h"
+ WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
+ COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}"
+ COMMAND ${PROTOC_EXE} -I${file_dir} --cpp_out=${CMAKE_BINARY_DIR} ${abs_file}
+ DEPENDS ${PROTOC_EXE} ${abs_file}
+ COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM)
+ endforeach()
+
+ set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE)
+ set(${c_var} ${${c_var}} PARENT_SCOPE)
+ set(${h_var} ${${h_var}} PARENT_SCOPE)
+
+endfunction()
+
+
if (SPM_USE_BUILTIN_PROTOBUF)
set(SPM_PROTO_HDRS builtin_pb/sentencepiece.pb.h)
@@ -52,12 +92,9 @@ if (SPM_USE_BUILTIN_PROTOBUF)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite)
include_directories(builtin_pb)
else()
- find_package(Protobuf REQUIRED)
- include_directories(${Protobuf_INCLUDE_DIRS})
- protobuf_generate_cpp(SPM_PROTO_SRCS SPM_PROTO_HDRS sentencepiece.proto)
- protobuf_generate_cpp(SPM_MODEL_PROTO_SRCS SPM_MODEL_PROTO_HDRS sentencepiece_model.proto)
- set(PROTOBUF_LITE_SRCS "")
- include_directories(${PROTOBUF_INCLUDE_DIR})
+ include_directories(${PROTOBUF_INC})
+ protobuf_generate(SPM_PROTO_SRCS SPM_PROTO_HDRS sentencepiece.proto)
+ protobuf_generate(SPM_MODEL_PROTO_SRCS SPM_MODEL_PROTO_HDRS sentencepiece_model.proto)
endif()
include_directories(${CMAKE_CURRENT_BINARY_DIR})
@@ -191,11 +228,13 @@ endif()
add_library(sentencepiece-static STATIC ${SPM_SRCS})
add_library(sentencepiece_train-static STATIC ${SPM_TRAIN_SRCS})
-target_link_libraries(sentencepiece-static INTERFACE ${SPM_LIBS})
+find_library(PROTO_LIB NAMES "libprotobuf.a" PATHS ${PROTOBUF_INC}/../lib NO_DEFAULT_PATH)
+
+target_link_libraries(sentencepiece-static INTERFACE ${PROTO_LIB} ${SPM_LIBS})
target_link_libraries(sentencepiece_train-static INTERFACE sentencepiece-static ${SPM_LIBS})
if (SPM_ENABLE_SHARED)
- target_link_libraries(sentencepiece ${SPM_LIBS})
+ target_link_libraries(sentencepiece ${SPM_LIBS} ${PROTO_LIB})
target_link_libraries(sentencepiece_train ${SPM_LIBS} sentencepiece)
set(SPM_INSTALLTARGETS sentencepiece sentencepiece_train sentencepiece-static sentencepiece_train-static)
set_target_properties(sentencepiece sentencepiece_train PROPERTIES SOVERSION 0 VERSION 0.0.0)
@@ -265,7 +304,7 @@ install(TARGETS ${SPM_INSTALLTARGETS}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
-install(FILES sentencepiece_trainer.h sentencepiece_processor.h
+install(FILES sentencepiece_trainer.h sentencepiece_processor.h "${CMAKE_BINARY_DIR}/sentencepiece_model.pb.h"
DESTINATION ${CMAKE_INSTALL_INCDIR})
file(TO_NATIVE_PATH "${PROJECT_SOURCE_DIR}/data" data_dir)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册