Commit bca2b1a0 authored by mindspore-ci-bot, committed by Gitee

!1236 [Data] Add NLP JiebaTokenizer operation

Merge pull request !1236 from xulei/cppjieba0518
...@@ -3042,6 +3042,16 @@ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS", AND
Why Three Licenses?
The zlib License could have been used instead of the Modified (3-clause) BSD License, and since the IJG License effectively subsumes the distribution conditions of the zlib License, this would have effectively placed libjpeg-turbo binary distributions under the IJG License. However, the IJG License specifically refers to the Independent JPEG Group and does not extend attribution and endorsement protections to other entities. Thus, it was desirable to choose a license that granted us the same protections for new code that were granted to the IJG for code derived from their software.
Software: cppjieba 5.0.3
Copyright notice:
Copyright 2005, Google Inc.
Copyright 2008, Google Inc.
Copyright 2007, Google Inc.
Copyright 2008 Google Inc.
Copyright 2006, Google Inc.
Copyright 2003 Google Inc.
Copyright 2009 Google Inc.
Copyright (C) 1991-2, RSA Data Security, Inc. Created 1991. All
Software: opencv 4.2.0
Copyright notice:
......
set(cppjieba_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(cppjieba_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(cppjieba
VER 5.0.3
HEAD_ONLY ./
URL https://codeload.github.com/yanyiwu/cppjieba/zip/v5.0.3
MD5 0dfef44bd32328c221f128b401e1a45c
PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/cppjieba/cppjieba.patch001)
include_directories(${cppjieba_INC}include)
include_directories(${cppjieba_INC}deps)
add_library(mindspore::cppjieba ALIAS cppjieba)
\ No newline at end of file
...@@ -57,6 +57,7 @@ if (ENABLE_MINDDATA)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/opencv.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/sqlite.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/tinyxml2.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/cppjieba.cmake)
endif()
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/gtest.cmake)
......
...@@ -61,6 +61,7 @@ set(submodules
$<TARGET_OBJECTS:kernels>
$<TARGET_OBJECTS:kernels-image>
$<TARGET_OBJECTS:kernels-data>
$<TARGET_OBJECTS:kernels-nlp>
$<TARGET_OBJECTS:APItoPython>
$<TARGET_OBJECTS:engine-datasetops-source>
$<TARGET_OBJECTS:engine-datasetops-source-sampler>
......
...@@ -38,6 +38,7 @@
#include "dataset/kernels/image/resize_op.h"
#include "dataset/kernels/image/uniform_aug_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/kernels/text/jieba_tokenizer_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
...@@ -406,6 +407,14 @@ void bindTensorOps4(py::module *m) {
py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB);
}
void bindTensorOps6(py::module *m) {
(void)py::class_<JiebaTokenizerOp, TensorOp, std::shared_ptr<JiebaTokenizerOp>>(*m, "JiebaTokenizerOp", "")
.def(py::init<const std::string, std::string, JiebaMode>(), py::arg("hmm_path"), py::arg("mp_path"),
py::arg("mode") = JiebaMode::kMix)
.def("add_word",
[](JiebaTokenizerOp &self, const std::string word, int freq) { THROW_IF_ERROR(self.AddWord(word, freq)); });
}
void bindSamplerOps(py::module *m) {
(void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler")
.def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
...@@ -500,6 +509,12 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("CELEBA", OpName::kCelebA)
.value("TEXTFILE", OpName::kTextFile);
(void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
.value("DE_INTER_JIEBA_MIX", JiebaMode::kMix)
.value("DE_INTER_JIEBA_MP", JiebaMode::kMp)
.value("DE_INTER_JIEBA_HMM", JiebaMode::kHmm)
.export_values();
(void)py::enum_<InterpolationMode>(m, "InterpolationMode", py::arithmetic())
.value("DE_INTER_LINEAR", InterpolationMode::kLinear)
.value("DE_INTER_CUBIC", InterpolationMode::kCubic)
...@@ -519,6 +534,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
bindTensorOps2(&m);
bindTensorOps3(&m);
bindTensorOps4(&m);
bindTensorOps6(&m);
bindSamplerOps(&m);
bindDatasetOps(&m);
bindInfoObjects(&m);
......
...@@ -2,6 +2,7 @@ add_subdirectory(image)
add_subdirectory(data)
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_subdirectory(text)
add_library(kernels OBJECT
py_func_op.cc
tensor_op.cc)
......
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(kernels-nlp OBJECT
jieba_tokenizer_op.cc
)
\ No newline at end of file
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/kernels/text/jieba_tokenizer_op.h"
#include <vector>
#include <memory>
#include <string>
#include "dataset/util/path.h"
namespace mindspore {
namespace dataset {
JiebaTokenizerOp::JiebaTokenizerOp(const std::string &hmm_path, const std::string &dict_path, JiebaMode mode)
: jieba_mode_(mode), hmm_model_path_(hmm_path), mp_dict_path_(dict_path) {
jieba_parser_ = std::make_unique<cppjieba::Jieba>(mp_dict_path_, hmm_model_path_, "");
}
Status JiebaTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
RETURN_UNEXPECTED_IF_NULL(jieba_parser_);
if (input->Rank() != 0 || input->type() != DataType::DE_STRING) {
    RETURN_STATUS_UNEXPECTED("The input tensor should be a scalar string tensor");
}
std::string_view sentence_v;
RETURN_IF_NOT_OK(input->GetItemAt(&sentence_v, {}));
std::string sentence{sentence_v};
std::vector<std::string> words;
if (sentence == "") {
words.push_back("");
} else {
if (jieba_mode_ == JiebaMode::kMp) {
jieba_parser_->CutSmall(sentence, words, MAX_WORD_LENGTH);
} else if (jieba_mode_ == JiebaMode::kHmm) {
jieba_parser_->CutHMM(sentence, words);
} else { // Mix
jieba_parser_->Cut(sentence, words, true);
}
}
*output = std::make_shared<Tensor>(words, TensorShape({(dsize_t)words.size()}));
return Status::OK();
}
Status JiebaTokenizerOp::AddWord(const std::string &word, int freq) {
RETURN_UNEXPECTED_IF_NULL(jieba_parser_);
if (jieba_parser_->InsertUserWord(word, freq, "") == false) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "add word error");
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_NLP_JIEBA_OP_H_
#define DATASET_ENGINE_NLP_JIEBA_OP_H_
#include <string>
#include <memory>
#include "cppjieba/Jieba.hpp"
#include "dataset/kernels/tensor_op.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
enum class JiebaMode { kMix = 0, kMp = 1, kHmm = 2 };
class JiebaTokenizerOp : public TensorOp {
public:
  // Default constant for the Jieba MPSegment algorithm.
static constexpr size_t MAX_WORD_LENGTH = 512;
// Constructor for JiebaTokenizerOp.
// @param hmm_path HMM model file.
// @param mp_path MP model file.
  // @param mode tokenization mode [Default "MIX"]: "MP" mode tokenizes with the MPSegment algorithm, "HMM" mode
  //     tokenizes with the Hidden Markov Model Segment algorithm, and "MIX" mode tokenizes with a mix of the
  //     MPSegment and HMMSegment algorithms.
JiebaTokenizerOp(const std::string &hmm_path, const std::string &mp_path, JiebaMode mode = JiebaMode::kMix);
~JiebaTokenizerOp() override = default;
  void Print(std::ostream &out) const override {
    out << "JiebaTokenizerOp: mode " << static_cast<int>(jieba_mode_) << ", hmm_model_path_ " << hmm_model_path_
        << ", mp_dict_path_ " << mp_dict_path_;
  }
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
  // @param word the word to be added to the JiebaTokenizer's dictionary.
  // @param freq [Default 0] the frequency of the word to be added.
Status AddWord(const std::string &word, int freq = 0);
protected:
std::string hmm_model_path_;
std::string mp_dict_path_;
std::unique_ptr<cppjieba::Jieba> jieba_parser_;
JiebaMode jieba_mode_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_NLP_JIEBA_OP_H_
...@@ -17,4 +17,4 @@ c_transforms and py_transforms. C_transforms is a high performance
image augmentation module which is developed with c++ opencv. Py_transforms
provide more kinds of image augmentations which is developed with python PIL.
"""
-from .utils import as_text
+from .utils import as_text, JiebaMode
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module c_transforms provides common NLP operations.
"""
import os
import re
import mindspore._c_dataengine as cde
from .utils import JiebaMode
from .validators import check_jieba_add_dict, check_jieba_add_word, check_jieba_init
DE_C_INTER_JIEBA_MODE = {
JiebaMode.MIX: cde.JiebaMode.DE_INTER_JIEBA_MIX,
JiebaMode.MP: cde.JiebaMode.DE_INTER_JIEBA_MP,
JiebaMode.HMM: cde.JiebaMode.DE_INTER_JIEBA_HMM
}
class JiebaTokenizer(cde.JiebaTokenizerOp):
"""
Tokenize Chinese string into words based on dictionary.
Args:
mode (Enum): [Default "MIX"], "MP" model will tokenize with MPSegment algorithm, "HMM" mode will
tokenize with Hiddel Markov Model Segment algorithm, "MIX" model will tokenize with a mix of MPSegment and
HMMSegment algorithm.
"""
@check_jieba_init
def __init__(self, hmm_path, mp_path, mode=JiebaMode.MIX):
self.mode = mode
self.__check_path__(hmm_path)
self.__check_path__(mp_path)
super().__init__(hmm_path, mp_path,
DE_C_INTER_JIEBA_MODE[mode])
@check_jieba_add_word
def add_word(self, word, freq=None):
"""
Add user defined word to JiebaTokenizer's dictionary
Args:
word(required, string): The word to be added to the JiebaTokenizer instance.
The added word will not be written into the built-in dictionary on disk.
freq(optional, int): The frequency of the word to be added,
The higher the frequency, the better change the word will be tokenized(default None,
use default frequency)
"""
if freq is None:
super().add_word(word, 0)
else:
super().add_word(word, freq)
@check_jieba_add_dict
def add_dict(self, user_dict):
"""
Add user defined word to JiebaTokenizer's dictionary
Args:
user_dict(path/dict):Dictionary to be added, file path or Python dictionary,
Python Dict format is {word1:freq1, word2:freq2,...}
Jieba dictionary format : word(required), freq(optional), such as:
word1 freq1
word2
word3 freq3
"""
if isinstance(user_dict, str):
self.__add_dict_py_file(user_dict)
elif isinstance(user_dict, dict):
for k, v in user_dict.items():
self.add_word(k, v)
else:
raise ValueError("the type of user_dict must str or dict")
def __add_dict_py_file(self, file_path):
"""Add user defined word by file"""
words_list = self.__parser_file(file_path)
for data in words_list:
if data[1] is None:
freq = 0
else:
freq = int(data[1])
self.add_word(data[0], freq)
    def __parser_file(self, file_path):
        """Parse the user-defined word file."""
        if not os.path.exists(file_path):
            raise ValueError(
                "user dict file {} does not exist".format(file_path))
        data_re = re.compile('^(.+?)( [0-9]+)?$', re.U)
        words_list = []
        # Use a context manager so the file handle is always closed.
        with open(file_path) as file_dict:
            for item in file_dict:
                data = item.strip()
                if not isinstance(data, str):
                    data = self.__decode(data)
                if not data:
                    continue
                matched = data_re.match(data)
                if matched is None:
                    raise ValueError(
                        "user dict file {} format error".format(file_path))
                words_list.append(matched.groups())
        return words_list
def __decode(self, data):
"""decode the dict file to utf8"""
try:
data = data.decode('utf-8')
except UnicodeDecodeError:
raise ValueError("user dict file must utf8")
return data.lstrip('\ufeff')
def __check_path__(self, model_path):
"""check model path"""
if not os.path.exists(model_path):
raise ValueError(
" jieba mode file {} is not exist".format(model_path))
...@@ -14,8 +14,10 @@
"""
Some basic function for nlp
"""
from enum import IntEnum
import numpy as np
def as_text(array, encoding='utf8'):
"""
Convert data of array to unicode.
...@@ -31,5 +33,13 @@ def as_text(array, encoding='utf8'):
if not isinstance(array, np.ndarray):
raise ValueError('input should be a numpy array')
-    byte_array = bytearray(list(array))
-    return byte_array.decode(encoding)
+    def decode(x):
+        return x.decode(encoding)
+    decode = np.vectorize(decode)
+    return decode(array)
class JiebaMode(IntEnum):
MIX = 0
MP = 1
HMM = 2
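A quick sketch of how the vectorized as_text above behaves, with illustrative values:

import numpy as np
from mindspore.dataset.transforms.text.utils import JiebaMode, as_text

# as_text decodes element-wise, so an array of bytes becomes an array of str.
arr = np.array([b"hello", "你好".encode("utf8")])
print(as_text(arr))        # -> ['hello' '你好']
print(int(JiebaMode.MIX))  # IntEnum members behave as plain ints -> 0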
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Validators for TensorOps.
"""
from functools import wraps
from ...transforms.validators import check_uint32
def check_jieba_init(method):
"""Wrapper method to check the parameters of jieba add word."""
@wraps(method)
def new_method(self, *args, **kwargs):
        hmm_path, mp_path, mode = (list(args) + 3 * [None])[:3]
if "hmm_path" in kwargs:
hmm_path = kwargs.get("hmm_path")
if "mp_path" in kwargs:
mp_path = kwargs.get("mp_path")
if hmm_path is None:
raise ValueError(
"the dict of HMMSegment in cppjieba is not provided")
kwargs["hmm_path"] = hmm_path
if mp_path is None:
raise ValueError(
"the dict of MPSegment in cppjieba is not provided")
kwargs["mp_path"] = mp_path
        if mode is not None:
            kwargs["mode"] = mode
return method(self, **kwargs)
return new_method
def check_jieba_add_word(method):
"""Wrapper method to check the parameters of jieba add word."""
@wraps(method)
def new_method(self, *args, **kwargs):
word, freq = (list(args) + 2 * [None])[:2]
if "word" in kwargs:
word = kwargs.get("word")
if "freq" in kwargs:
freq = kwargs.get("freq")
if word is None:
raise ValueError("word is not provided")
kwargs["word"] = word
if freq is not None:
check_uint32(freq)
kwargs["freq"] = freq
return method(self, **kwargs)
return new_method
def check_jieba_add_dict(method):
"""Wrapper method to check the parameters of add dict"""
@wraps(method)
def new_method(self, *args, **kwargs):
user_dict = (list(args) + [None])[0]
if "user_dict" in kwargs:
user_dict = kwargs.get("user_dict")
if user_dict is None:
raise ValueError("user_dict is not provided")
kwargs["user_dict"] = user_dict
return method(self, **kwargs)
return new_method
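All three validators share one pattern: positional arguments are normalized into kwargs, checked, then forwarded. A toy sketch of the effect (the _Toy class is hypothetical, for illustration only):

from mindspore.dataset.transforms.text.validators import check_jieba_add_word

class _Toy:
    @check_jieba_add_word
    def add_word(self, word, freq=None):
        return word, freq

print(_Toy().add_word("桥", 5))  # positional args arrive normalized as kwargs -> ('桥', 5)
# _Toy().add_word(None)          # would raise ValueError("word is not provided")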
...@@ -68,6 +68,7 @@ SET(DE_UT_SRCS
text_file_op_test.cc
filter_op_test.cc
concat_op_test.cc
jieba_tokenizer_op_test.cc
)
add_executable(de_ut_tests ${DE_UT_SRCS})
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <string_view>
#include "common/common.h"
#include "dataset/kernels/text/jieba_tokenizer_op.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
class MindDataTestJiebaTokenizerOp : public UT::DatasetOpTesting {
public:
void CheckEqual(const std::shared_ptr<Tensor> &o, const std::vector<dsize_t> &index, const std::string &expect) {
std::string_view str;
Status s = o->GetItemAt(&str, index);
EXPECT_TRUE(s.IsOk());
EXPECT_EQ(str, expect);
}
};
TEST_F(MindDataTestJiebaTokenizerOp, TestJieba_opFunctions) {
  MS_LOG(INFO) << "Doing MindDataTestJiebaTokenizerOp TestJieba_opFunctions.";
std::string dataset_path = datasets_root_path_ + "/jiebadict";
std::string hmm_path = dataset_path + "/hmm_model.utf8";
std::string mp_path = dataset_path + "/jieba.dict.utf8";
std::shared_ptr<Tensor> output_tensor;
std::unique_ptr<JiebaTokenizerOp> op(new JiebaTokenizerOp(hmm_path, mp_path));
std::shared_ptr<Tensor> input_tensor = std::make_shared<Tensor>("今天天气太好了我们一起去外面玩吧");
Status s = op->Compute(input_tensor, &output_tensor);
EXPECT_TRUE(s.IsOk());
EXPECT_EQ(output_tensor->Rank(), 1);
EXPECT_EQ(output_tensor->Size(), 7);
CheckEqual(output_tensor, {0}, "今天天气");
CheckEqual(output_tensor, {1}, "太好了");
CheckEqual(output_tensor, {2}, "我们");
CheckEqual(output_tensor, {3}, "一起");
CheckEqual(output_tensor, {4}, "去");
CheckEqual(output_tensor, {5}, "外面");
CheckEqual(output_tensor, {6}, "玩吧");
}
TEST_F(MindDataTestJiebaTokenizerOp, TestJieba_opAdd) {
MS_LOG(INFO) << "Doing MindDataTestJiebaTokenizerOp TestJieba_opAdd.";
std::string dataset_path = datasets_root_path_ + "/jiebadict";
std::string hmm_path = dataset_path + "/hmm_model.utf8";
std::string mp_path = dataset_path + "/jieba.dict.utf8";
std::shared_ptr<Tensor> output_tensor;
std::unique_ptr<JiebaTokenizerOp> op(new JiebaTokenizerOp(hmm_path, mp_path));
op->AddWord("男默女泪");
std::shared_ptr<Tensor> input_tensor = std::make_shared<Tensor>("男默女泪");
Status s = op->Compute(input_tensor, &output_tensor);
EXPECT_TRUE(s.IsOk());
EXPECT_EQ(output_tensor->Rank(), 1);
EXPECT_EQ(output_tensor->Size(), 1);
CheckEqual(output_tensor, {0}, "男默女泪");
}
TEST_F(MindDataTestJiebaTokenizerOp, TestJieba_opEmpty) {
MS_LOG(INFO) << "Doing MindDataTestJiebaTokenizerOp TestJieba_opEmpty.";
std::string dataset_path = datasets_root_path_ + "/jiebadict";
std::string hmm_path = dataset_path + "/hmm_model.utf8";
std::string mp_path = dataset_path + "/jieba.dict.utf8";
std::shared_ptr<Tensor> output_tensor;
std::unique_ptr<JiebaTokenizerOp> op(new JiebaTokenizerOp(hmm_path, mp_path));
op->AddWord("男默女泪");
std::shared_ptr<Tensor> input_tensor = std::make_shared<Tensor>("");
Status s = op->Compute(input_tensor, &output_tensor);
EXPECT_TRUE(s.IsOk());
EXPECT_EQ(output_tensor->Rank(), 1);
EXPECT_EQ(output_tensor->Size(), 1);
CheckEqual(output_tensor, {0}, "");
}
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
今天天气太好了我们一起去外面玩吧
\ No newline at end of file
男默女泪市长江大桥
\ No newline at end of file
江州市长江大桥参加了长江大桥的通车仪式
\ No newline at end of file
天空好蓝
美丽心灵 10
男默女泪 3
\ No newline at end of file
...@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
-import mindspore.dataset.transforms.nlp.utils as nlp
+import mindspore.dataset.transforms.text.utils as nlp
from mindspore import log as logger
DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
...@@ -42,7 +42,8 @@ def test_textline_dataset_totext():
ds.config.set_num_parallel_workers(4)
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
count = 0
-line = ["This is a text file.", "Another file.", "Be happy every day.", "End of file.", "Good luck to everyone."]
+line = ["This is a text file.", "Another file.",
+        "Be happy every day.", "End of file.", "Good luck to everyone."]
for i in data.create_dict_iterator():
str = i["text"].item().decode("utf8")
assert (str == line[count])
......
...@@ -24,7 +24,7 @@ def test_flat_map_1():
'''
DATA_FILE records the path of image folders, load the images from them.
'''
-import mindspore.dataset.transforms.nlp.utils as nlp
+import mindspore.dataset.transforms.text.utils as nlp
def flat_map_func(x):
data_dir = x[0].item().decode('utf8')
...@@ -45,7 +45,7 @@ def test_flat_map_2():
'''
Flatten 3D structure data
'''
-import mindspore.dataset.transforms.nlp.utils as nlp
+import mindspore.dataset.transforms.text.utils as nlp
def flat_map_func_1(x):
data_dir = x[0].item().decode('utf8')
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import mindspore.dataset as ds
from mindspore.dataset.transforms.text.c_transforms import JiebaTokenizer
from mindspore.dataset.transforms.text.utils import JiebaMode, as_text
DATA_FILE = "../data/dataset/testJiebaDataset/3.txt"
DATA_ALL_FILE = "../data/dataset/testJiebaDataset/*"
HMM_FILE = "../data/dataset/jiebadict/hmm_model.utf8"
MP_FILE = "../data/dataset/jiebadict/jieba.dict.utf8"
def test_jieba_1():
"""Test jieba tokenizer with MP mode"""
data = ds.TextFileDataset(DATA_FILE)
jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP)
data = data.map(input_columns=["text"],
operations=jieba_op, num_parallel_workers=1)
expect = ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧']
ret = []
for i in data.create_dict_iterator():
ret = as_text(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
def test_jieba_1_1():
"""Test jieba tokenizer with HMM mode"""
data = ds.TextFileDataset(DATA_FILE)
jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.HMM)
data = data.map(input_columns=["text"],
operations=jieba_op, num_parallel_workers=1)
expect = ['今天', '天气', '太', '好', '了', '我们', '一起', '去', '外面', '玩', '吧']
for i in data.create_dict_iterator():
ret = as_text(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
def test_jieba_1_2():
"""Test jieba tokenizer with HMM MIX"""
data = ds.TextFileDataset(DATA_FILE)
jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MIX)
data = data.map(input_columns=["text"],
operations=jieba_op, num_parallel_workers=1)
expect = ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧']
for i in data.create_dict_iterator():
ret = as_text(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
def test_jieba_2():
"""Test add_word"""
DATA_FILE4 = "../data/dataset/testJiebaDataset/4.txt"
data = ds.TextFileDataset(DATA_FILE4)
jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP)
jieba_op.add_word("男默女泪")
expect = ['男默女泪', '市', '长江大桥']
data = data.map(input_columns=["text"],
operations=jieba_op, num_parallel_workers=2)
for i in data.create_dict_iterator():
ret = as_text(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
def test_jieba_2_1():
"""Test add_word with freq"""
DATA_FILE4 = "../data/dataset/testJiebaDataset/4.txt"
data = ds.TextFileDataset(DATA_FILE4)
jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP)
jieba_op.add_word("男默女泪", 10)
data = data.map(input_columns=["text"],
operations=jieba_op, num_parallel_workers=2)
expect = ['男默女泪', '市', '长江大桥']
for i in data.create_dict_iterator():
ret = as_text(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
def test_jieba_2_2():
"""Test add_word with invalid None Input"""
jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP)
try:
jieba_op.add_word(None)
except ValueError:
pass
def test_jieba_2_3():
"""Test add_word with freq, the value of freq affects the result of segmentation"""
DATA_FILE4 = "../data/dataset/testJiebaDataset/6.txt"
data = ds.TextFileDataset(DATA_FILE4)
jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP)
jieba_op.add_word("江大桥", 20000)
data = data.map(input_columns=["text"],
operations=jieba_op, num_parallel_workers=2)
expect = ['江州', '市长', '江大桥', '参加', '了', '长江大桥', '的', '通车', '仪式']
for i in data.create_dict_iterator():
ret = as_text(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
def test_jieba_3():
"""Test add_dict with dict"""
DATA_FILE4 = "../data/dataset/testJiebaDataset/4.txt"
user_dict = {
"男默女泪": 10
}
data = ds.TextFileDataset(DATA_FILE4)
jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP)
jieba_op.add_dict(user_dict)
data = data.map(input_columns=["text"],
operations=jieba_op, num_parallel_workers=1)
expect = ['男默女泪', '市', '长江大桥']
for i in data.create_dict_iterator():
ret = as_text(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
def test_jieba_3_1():
"""Test add_dict with dict"""
DATA_FILE4 = "../data/dataset/testJiebaDataset/4.txt"
user_dict = {
"男默女泪": 10,
"江大桥": 20000
}
data = ds.TextFileDataset(DATA_FILE4)
jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP)
jieba_op.add_dict(user_dict)
data = data.map(input_columns=["text"],
operations=jieba_op, num_parallel_workers=1)
expect = ['男默女泪', '市长', '江大桥']
for i in data.create_dict_iterator():
ret = as_text(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
def test_jieba_4():
    """Test add_dict with a user dict file"""
    DATA_FILE4 = "../data/dataset/testJiebaDataset/3.txt"
DICT_FILE = "../data/dataset/testJiebaDataset/user_dict.txt"
data = ds.TextFileDataset(DATA_FILE4)
jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP)
jieba_op.add_dict(DICT_FILE)
data = data.map(input_columns=["text"],
operations=jieba_op, num_parallel_workers=1)
expect = ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧']
for i in data.create_dict_iterator():
ret = as_text(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
def test_jieba_4_1():
"""Test add dict with invalid file path"""
DICT_FILE = ""
jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP)
try:
jieba_op.add_dict(DICT_FILE)
except ValueError:
pass
def test_jieba_5():
    """Test add_word with freq on a longer sentence"""
DATA_FILE4 = "../data/dataset/testJiebaDataset/6.txt"
data = ds.TextFileDataset(DATA_FILE4)
jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP)
jieba_op.add_word("江大桥", 20000)
data = data.map(input_columns=["text"],
operations=jieba_op, num_parallel_workers=1)
expect = ['江州', '市长', '江大桥', '参加', '了', '长江大桥', '的', '通车', '仪式']
for i in data.create_dict_iterator():
ret = as_text(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
def gen():
text = np.array("今天天气太好了我们一起去外面玩吧".encode("UTF8"), dtype='S')
yield text,
def pytoken_op(input_data):
te = str(as_text(input_data))
tokens = []
tokens.append(te[:5].encode("UTF8"))
tokens.append(te[5:10].encode("UTF8"))
tokens.append(te[10:].encode("UTF8"))
return np.array(tokens, dtype='S')
def test_jieba_6():
data = ds.GeneratorDataset(gen, column_names=["text"])
data = data.map(input_columns=["text"],
operations=pytoken_op, num_parallel_workers=1)
expect = ['今天天气太', '好了我们一', '起去外面玩吧']
for i in data.create_dict_iterator():
ret = as_text(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
if __name__ == "__main__":
test_jieba_1()
test_jieba_1_1()
test_jieba_1_2()
test_jieba_2()
test_jieba_2_1()
    test_jieba_2_2()
    test_jieba_2_3()
test_jieba_3()
test_jieba_3_1()
test_jieba_4()
test_jieba_4_1()
    test_jieba_5()
test_jieba_6()
diff -Npur cppjieba/include/cppjieba/Jieba.hpp cppjiebap/include/cppjieba/Jieba.hpp
--- cppjieba/include/cppjieba/Jieba.hpp 2020-05-07 15:27:16.490147073 +0800
+++ cppjiebap/include/cppjieba/Jieba.hpp 2020-05-07 15:51:15.315931163 +0800
@@ -10,17 +10,14 @@ class Jieba {
public:
Jieba(const string& dict_path,
const string& model_path,
- const string& user_dict_path,
- const string& idfPath,
- const string& stopWordPath)
+ const string& user_dict_path)
: dict_trie_(dict_path, user_dict_path),
model_(model_path),
mp_seg_(&dict_trie_),
hmm_seg_(&model_),
mix_seg_(&dict_trie_, &model_),
full_seg_(&dict_trie_),
- query_seg_(&dict_trie_, &model_),
- extractor(&dict_trie_, &model_, idfPath, stopWordPath) {
+ query_seg_(&dict_trie_, &model_) {
}
~Jieba() {
}
@@ -121,8 +118,6 @@ class Jieba {
FullSegment full_seg_;
QuerySegment query_seg_;
- public:
- KeywordExtractor extractor;
}; // class Jieba
} // namespace cppjieba
diff -Npur cppjieba/test/demo.cpp cppjiebap/test/demo.cpp
--- cppjieba/test/demo.cpp 2020-05-07 15:27:16.490147073 +0800
+++ cppjiebap/test/demo.cpp 2020-05-07 15:53:21.630248552 +0800
@@ -11,9 +11,7 @@ const char* const STOP_WORD_PATH = "../d
int main(int argc, char** argv) {
cppjieba::Jieba jieba(DICT_PATH,
HMM_PATH,
- USER_DICT_PATH,
- IDF_PATH,
- STOP_WORD_PATH);
+ USER_DICT_PATH);
vector<string> words;
vector<cppjieba::Word> jiebawords;
string s;
@@ -71,10 +69,5 @@ int main(int argc, char** argv) {
cout << tagres << endl;
cout << "[demo] Keyword Extraction" << endl;
- const size_t topk = 5;
- vector<cppjieba::KeywordExtractor::Word> keywordres;
- jieba.extractor.Extract(s, keywordres, topk);
- cout << s << endl;
- cout << keywordres << endl;
return EXIT_SUCCESS;
}
diff -Npur cppjieba/test/unittest/jieba_test.cpp cppjiebap/test/unittest/jieba_test.cpp
--- cppjieba/test/unittest/jieba_test.cpp 2020-05-07 15:27:16.522146752 +0800
+++ cppjiebap/test/unittest/jieba_test.cpp 2020-05-07 15:59:11.630860061 +0800
@@ -6,9 +6,7 @@ using namespace cppjieba;
TEST(JiebaTest, Test1) {
cppjieba::Jieba jieba("../dict/jieba.dict.utf8",
"../dict/hmm_model.utf8",
- "../dict/user.dict.utf8",
- "../dict/idf.utf8",
- "../dict/stop_words.utf8");
+ "../dict/user.dict.utf8");
vector<string> words;
string result;
@@ -43,9 +41,7 @@ TEST(JiebaTest, Test1) {
TEST(JiebaTest, WordTest) {
cppjieba::Jieba jieba("../dict/jieba.dict.utf8",
"../dict/hmm_model.utf8",
- "../dict/user.dict.utf8",
- "../dict/idf.utf8",
- "../dict/stop_words.utf8");
+ "../dict/user.dict.utf8");
vector<Word> words;
string result;
@@ -85,9 +81,7 @@ TEST(JiebaTest, WordTest) {
TEST(JiebaTest, InsertUserWord) {
cppjieba::Jieba jieba("../dict/jieba.dict.utf8",
"../dict/hmm_model.utf8",
- "../dict/user.dict.utf8",
- "../dict/idf.utf8",
- "../dict/stop_words.utf8");
+ "../dict/user.dict.utf8");
vector<string> words;
string result;
@@ -120,14 +114,4 @@ TEST(JiebaTest, InsertUserWord) {
jieba.Cut("同一个世界,同一个梦想", words);
result = Join(words.begin(), words.end(), "/");
ASSERT_EQ(result, "同一个世界,同一个梦想");
-
- {
- string s("一部iPhone6");
- string res;
- vector<KeywordExtractor::Word> wordweights;
- size_t topN = 5;
- jieba.extractor.Extract(s, wordweights, topN);
- res << wordweights;
- ASSERT_EQ(res, "[{\"word\": \"iPhone6\", \"offset\": [6], \"weight\": 11.7392}, {\"word\": \"\xE4\xB8\x80\xE9\x83\xA8\", \"offset\": [0], \"weight\": 6.47592}]");
- }
}