提交 bde9f18f 编写于 作者: Z Zirui Wu

update lookup api to take in a type

ci

add test case

address some of the review cmts

address review cmts
上级 75045e3e
......@@ -121,12 +121,13 @@ PYBIND_REGISTER(UnicodeCharTokenizerOp, 1, ([](const py::module *m) {
PYBIND_REGISTER(LookupOp, 1, ([](const py::module *m) {
(void)py::class_<LookupOp, TensorOp, std::shared_ptr<LookupOp>>(*m, "LookupOp")
.def(py::init([](std::shared_ptr<Vocab> vocab, const py::object &py_word) {
.def(py::init([](std::shared_ptr<Vocab> vocab, const py::object &py_word,
const DataType &data_type) {
if (vocab == nullptr) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "vocab object type is incorrect or null."));
}
if (py_word.is_none()) {
return std::make_shared<LookupOp>(vocab, Vocab::kNoTokenExists);
return std::make_shared<LookupOp>(vocab, Vocab::kNoTokenExists, data_type);
}
std::string word = py::reinterpret_borrow<py::str>(py_word);
WordIdType default_id = vocab->Lookup(word);
......@@ -134,7 +135,7 @@ PYBIND_REGISTER(LookupOp, 1, ([](const py::module *m) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError,
"default unknown token: " + word + " doesn't exist in vocab."));
}
return std::make_shared<LookupOp>(vocab, default_id);
return std::make_shared<LookupOp>(vocab, default_id, data_type);
}));
}));
......
......@@ -22,8 +22,9 @@ namespace dataset {
namespace api {
namespace text {
std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token) {
auto op = std::make_shared<LookupOperation>(vocab, unknown_token);
std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
const DataType &data_type) {
auto op = std::make_shared<LookupOperation>(vocab, unknown_token, data_type);
if (!op->ValidateParams()) {
return nullptr;
......@@ -32,8 +33,9 @@ std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, con
}
// LookupOperation
LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token)
: vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists) {}
LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
const DataType &data_type)
: vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists), data_type_(data_type) {}
bool LookupOperation::ValidateParams() {
if (vocab_ == nullptr) {
......@@ -54,7 +56,7 @@ bool LookupOperation::ValidateParams() {
}
std::shared_ptr<TensorOp> LookupOperation::Build() {
std::shared_ptr<LookupOp> tensor_op = std::make_shared<LookupOp>(vocab_, default_id_);
std::shared_ptr<LookupOp> tensor_op = std::make_shared<LookupOp>(vocab_, default_id_, data_type_);
return tensor_op;
}
......
......@@ -20,9 +20,11 @@
#include <vector>
#include <memory>
#include <string>
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/text/vocab.h"
#include "mindspore/ccsrc/minddata/dataset/core/data_type.h"
namespace mindspore {
namespace dataset {
......@@ -37,15 +39,18 @@ class LookupOperation;
/// \brief Lookup operator that looks up a word to an id.
/// \param[in] vocab a Vocab object.
/// \param[in] unknown_token word to use for lookup if the word being looked up is out of Vocabulary (oov).
/// If unknown_token is oov, runtime error will be thrown
/// If unknown_token is oov, runtime error will be thrown.
/// \param[in] data_type type of the tensor after lookup, typically int32.
/// \return Shared pointer to the current TensorOperation.
std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token);
std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
const mindspore::dataset::DataType &data_type = DataType("int32"));
/* ####################################### Derived TensorOperation classes ################################# */
class LookupOperation : public TensorOperation {
public:
explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token);
explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
const DataType &data_type);
~LookupOperation() = default;
......@@ -57,6 +62,7 @@ class LookupOperation : public TensorOperation {
std::shared_ptr<Vocab> vocab_;
std::string unknown_token_;
int32_t default_id_;
DataType data_type_;
};
} // namespace text
} // namespace api
......
......@@ -13,15 +13,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/text/kernels/lookup_op.h"
#include <string>
#include "minddata/dataset/kernels/data/data_utils.h"
#include "minddata/dataset/text/kernels/lookup_op.h"
namespace mindspore {
namespace dataset {
LookupOp::LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id)
: vocab_(vocab), default_id_(default_id), type_(DataType("int32")) {}
LookupOp::LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id, const DataType &data_type)
: vocab_(vocab), default_id_(default_id), type_(data_type) {}
Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
......@@ -37,6 +38,14 @@ Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
"Lookup Error: token: " + std::string(*itr) + " doesn't exist in vocab and no unknown token is specified.");
}
RETURN_IF_NOT_OK(Tensor::CreateFromVector(word_ids, input->shape(), output));
// type cast to user's requirements if what user wants isn't int32_t
if ((*output)->type() != type_) {
std::shared_ptr<Tensor> cast_to;
RETURN_IF_NOT_OK(TypeCast(*output, &cast_to, type_));
*output = cast_to;
}
return Status::OK();
}
Status LookupOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
......
......@@ -18,9 +18,9 @@
#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_LOOKUP_OP_H_
#include <memory>
#include <vector>
#include <utility>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
......@@ -31,26 +31,27 @@ namespace mindspore {
namespace dataset {
class LookupOp : public TensorOp {
public:
// constructor for lookup, takes in a vocab object
// @param std::shared_ptr<Vocab> vocab -
// @param WordIdType default_id, id to lookup if a word is not in vocab
explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id = 1);
/// \brief constructor for lookup, takes in a vocab object.
/// \param[in] vocab vocab used for lookup.
/// \param[in] default_id id to look up if a word is not in vocab.
/// \param[in] data_type type of the tensor after lookup, mostly int32.
explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id, const DataType &data_type);
~LookupOp() = default;
// perform actual lookup on each tensor
// @param const std::shared_ptr<Tensor> &input
// @param std::shared_ptr<Tensor> *output
// @return error code
/// \brief perform actual lookup on each tensor.
/// \param[in] input tensor to perform the lookup on.
/// \param[out] output tensor produced by the lookup.
/// \return error code.
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
// print method
// @param std::ostream out
/// \brief print method.
/// \param[in] out output stream to print to.
void Print(std::ostream &out) const override;
// @param std::vector<DataType> &inputs -
// @param std::vector<DataType> &outputs -
// @return error code
/// \param[in] inputs data types of the inputs.
/// \param[out] outputs data types of the outputs.
/// \return error code.
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
std::string Name() const override { return kLookupOp; }
......
......@@ -49,6 +49,7 @@ import platform
import numpy as np
import mindspore._c_dataengine as cde
import mindspore.common.dtype as mstype
from .utils import JiebaMode, NormalizeForm, to_str, SPieceTokenizerOutType, SPieceTokenizerLoadType
from .validators import check_lookup, check_jieba_add_dict, \
......@@ -66,11 +67,12 @@ class Lookup(cde.LookupOp):
vocab(Vocab): a Vocab object.
unknown_token(str, optional): word to use for lookup if the word being looked up is out of Vocabulary (oov).
If unknown_token is oov, runtime error will be thrown (default=None).
data_type (mindspore.dtype, optional): mindspore.dtype that lookup maps string to (default=mstype.int32).
"""
@check_lookup
def __init__(self, vocab, unknown_token=None):
super().__init__(vocab, unknown_token)
def __init__(self, vocab, unknown_token=None, data_type=mstype.int32):
super().__init__(vocab, unknown_token, mstype_to_detype(data_type))
class SlidingWindow(cde.SlidingWindowOp):
......@@ -103,7 +105,6 @@ class SlidingWindow(cde.SlidingWindowOp):
super().__init__(width, axis)
class Ngram(cde.NgramOp):
"""
TensorOp to generate n-gram from a 1-D string Tensor.
......
......@@ -44,12 +44,13 @@ def check_lookup(method):
@wraps(method)
def new_method(self, *args, **kwargs):
[vocab, unknown_token], _ = parse_user_args(method, *args, **kwargs)
[vocab, unknown_token, data_type], _ = parse_user_args(method, *args, **kwargs)
if unknown_token is not None:
type_check(unknown_token, (str,), "unknown_token")
type_check(vocab, (cde.Vocab,), "vocab is not an instance of cde.Vocab.")
type_check(data_type, (typing.Type,), "data_type")
return method(self, *args, **kwargs)
......@@ -327,6 +328,7 @@ def check_from_dataset(method):
return new_method
def check_slidingwindow(method):
"""A wrapper that wraps a parameter checker to the original function(sliding window operation)."""
......@@ -339,6 +341,7 @@ def check_slidingwindow(method):
return new_method
def check_ngram(method):
"""A wrapper that wraps a parameter checker to the original function."""
......
......@@ -26,9 +26,10 @@
#include "minddata/dataset/include/text.h"
using namespace mindspore::dataset::api;
using mindspore::dataset::DataType;
using mindspore::dataset::ShuffleMode;
using mindspore::dataset::Tensor;
using mindspore::dataset::Status;
using mindspore::dataset::Tensor;
using mindspore::dataset::Vocab;
class MindDataTestPipeline : public UT::DatasetOpTesting {
......@@ -50,7 +51,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOp) {
EXPECT_EQ(s, Status::OK());
// Create Lookup operation on ds
std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>");
std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>", DataType("int32"));
EXPECT_NE(lookup, nullptr);
// Create Map operation on ds
......@@ -94,7 +95,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpFail1) {
// Create lookup op for ds
// Expected failure: "<unk>" is not a word of vocab
std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>");
std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>", DataType("int32"));
EXPECT_EQ(lookup, nullptr);
}
......@@ -105,7 +106,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpFail2) {
// Create lookup op
// Expected failure: vocab is null
std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "");
std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "", DataType("int32"));
EXPECT_EQ(lookup, nullptr);
}
......@@ -126,7 +127,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpWithEmptyUnknownToken) {
// Create Lookup operation on ds
// Expected failure: "" is not a word of vocab
std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "");
std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "", DataType("int32"));
EXPECT_EQ(lookup, nullptr);
}
......@@ -148,7 +149,7 @@ TEST_F(MindDataTestPipeline, TestVocabFromDataset) {
EXPECT_EQ(home_index, 4);
// Create Lookup operation on ds
std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>");
std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>", DataType("int32"));
EXPECT_NE(lookup, nullptr);
// Create Map operation on ds
......@@ -212,12 +213,15 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetDefault) {
uint64_t i = 0;
std::vector<int32_t> expected = {2, 3, 1, 4, 5, 0};
std::vector<int64_t> not_expected = {2, 3, 1, 4, 5, 0};
while (row.size() != 0) {
auto ind = row["text"];
MS_LOG(INFO) << ind->shape() << " " << *ind;
std::shared_ptr<Tensor> expected_item;
std::shared_ptr<Tensor> expected_item, not_expected_item;
Tensor::CreateScalar(expected[i], &expected_item);
Tensor::CreateScalar(not_expected[i], &not_expected_item);
EXPECT_EQ(*ind, *expected_item);
EXPECT_NE(*ind, *not_expected_item);
iter->GetNextRow(&row);
i++;
}
......@@ -233,8 +237,8 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail1) {
// Create vocab from dataset
// Expected failure: top_k can not be negative
std::shared_ptr<Vocab> vocab = ds->BuildVocab({"text"}, {0, std::numeric_limits<int64_t>::max()},
-2, {"<pad>", "<unk>"}, true);
std::shared_ptr<Vocab> vocab =
ds->BuildVocab({"text"}, {0, std::numeric_limits<int64_t>::max()}, -2, {"<pad>", "<unk>"}, true);
EXPECT_EQ(vocab, nullptr);
}
......@@ -247,9 +251,9 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail2) {
EXPECT_NE(ds, nullptr);
// Create vocab from dataset
// Expected failure: requency_range [a,b] should be 0 <= a <= b
std::shared_ptr<Vocab> vocab = ds->BuildVocab({"text"}, {4, 1},
std::numeric_limits<int64_t>::max(), {"<pad>", "<unk>"}, true);
// Expected failure: frequency_range [a,b] should be 0 <= a <= b
std::shared_ptr<Vocab> vocab =
ds->BuildVocab({"text"}, {4, 1}, std::numeric_limits<int64_t>::max(), {"<pad>", "<unk>"}, true);
EXPECT_EQ(vocab, nullptr);
}
......@@ -266,3 +270,52 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail3) {
std::shared_ptr<Vocab> vocab = ds->BuildVocab({"ColumnNotExist"});
EXPECT_EQ(vocab, nullptr);
}
/// \brief Builds a vocab from a text file, maps the dataset through text::Lookup with an int64 output
///     type, and checks every produced id against the expected int64 values.
TEST_F(MindDataTestPipeline, TestVocabFromDatasetInt64) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVocabFromDatasetInt64.";

  // Create a TextFile dataset
  std::string data_file = datasets_root_path_ + "/testVocab/words.txt";
  std::shared_ptr<Dataset> ds = TextFile({data_file}, 0, ShuffleMode::kFalse);
  // ds is dereferenced below, so a null here must abort the test instead of crashing it
  ASSERT_NE(ds, nullptr);

  // Create vocab from dataset
  std::shared_ptr<Vocab> vocab = ds->BuildVocab();
  ASSERT_NE(vocab, nullptr);

  // Check if vocab has words or not
  int32_t home_index = vocab->Lookup("home");
  EXPECT_EQ(home_index, 2);

  // Create Lookup operation on ds, requesting int64 ids instead of the default int32
  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "home", DataType("int64"));
  EXPECT_NE(lookup, nullptr);

  // Create Map operation on ds
  ds = ds->Map({lookup});
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  uint64_t i = 0;
  std::vector<int64_t> expected = {2, 3, 1, 4, 5, 0};
  // same values with a different element type; used to check the output is not equal to a non-int64 tensor
  std::vector<int8_t> not_expected = {2, 3, 1, 4, 5, 0};
  while (row.size() != 0) {
    // guard the index: more rows than expected values must fail the test, not read out of bounds
    ASSERT_LT(i, expected.size());
    auto ind = row["text"];
    MS_LOG(INFO) << ind->shape() << " " << *ind;
    std::shared_ptr<Tensor> expected_item, not_expected_item;
    Tensor::CreateScalar(expected[i], &expected_item);
    Tensor::CreateScalar(not_expected[i], &not_expected_item);
    EXPECT_EQ(*ind, *expected_item);
    EXPECT_NE(*ind, *not_expected_item);
    iter->GetNextRow(&row);
    i++;
  }
  // make sure the loop actually ran over every expected row (an empty pipeline must not pass)
  EXPECT_EQ(i, expected.size());
}
\ No newline at end of file
......@@ -17,6 +17,7 @@ import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as text
import mindspore.common.dtype as mstype
# this file contains "home is behind the world head" each word is 1 line
DATA_FILE = "../data/dataset/testVocab/words.txt"
......@@ -137,6 +138,36 @@ def test_from_file():
assert "Input vocab_size must be greater than 0" in test_config("w1 w2", 0, [], True)
assert "Input vocab_size must be greater than 0" in test_config("w1 w2", -1, [], True)
def test_lookup_cast_type():
    """Check that text.Lookup outputs ids in the requested data_type and rejects a non-dtype argument."""

    def gen(texts):
        # one row per space-separated word, each as a bytes ('S') numpy scalar array
        for word in texts.split(" "):
            yield (np.array(word, dtype='S'),)

    def test_config(lookup_str, data_type=None):
        """Run Lookup over lookup_str and return the dtype of the first result, or the error text."""
        try:
            vocab = text.Vocab.from_list(["w1", "w2", "w3"], special_tokens=["<unk>"], special_first=True)
            data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"])
            # if data_type is None, test the default value of data_type
            op = text.Lookup(vocab, "<unk>") if data_type is None else text.Lookup(vocab, "<unk>", data_type)
            data = data.map(input_columns=["text"], operations=op)
            res = []
            for d in data.create_dict_iterator(num_epochs=1):
                res.append(d["text"])
            return res[0].dtype
        except (ValueError, RuntimeError, TypeError) as e:
            return str(e)

    # test result is correct
    assert test_config("w1", mstype.int8) == np.dtype("int8")
    assert test_config("w2", mstype.int32) == np.dtype("int32")
    assert test_config("w3", mstype.int64) == np.dtype("int64")
    # pin the exact dtype: a bare "!= int32" check would also pass if test_config returned an error string
    assert test_config("unk", mstype.float32) == np.dtype("float32")
    assert test_config("unk") == np.dtype("int32")
    # test exception, data_type isn't the correct type
    assert "tldr is not of type (<class 'mindspore._c_expression.typing.Type'>,)" in test_config("unk", "tldr")
if __name__ == '__main__':
test_from_dict_exception()
test_from_list_tutorial()
......@@ -144,3 +175,4 @@ if __name__ == '__main__':
test_from_dict_tutorial()
test_from_list()
test_from_file()
test_lookup_cast_type()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册