Commit bde9f18f authored by Zirui Wu

Update the Lookup API to take in a data type

ci

Add test case

Address some of the review comments

Address review comments
Parent 75045e3e
@@ -121,12 +121,13 @@ PYBIND_REGISTER(UnicodeCharTokenizerOp, 1, ([](const py::module *m) {
 PYBIND_REGISTER(LookupOp, 1, ([](const py::module *m) {
                   (void)py::class_<LookupOp, TensorOp, std::shared_ptr<LookupOp>>(*m, "LookupOp")
-                    .def(py::init([](std::shared_ptr<Vocab> vocab, const py::object &py_word) {
+                    .def(py::init([](std::shared_ptr<Vocab> vocab, const py::object &py_word,
+                                     const DataType &data_type) {
                       if (vocab == nullptr) {
                         THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "vocab object type is incorrect or null."));
                       }
                       if (py_word.is_none()) {
-                        return std::make_shared<LookupOp>(vocab, Vocab::kNoTokenExists);
+                        return std::make_shared<LookupOp>(vocab, Vocab::kNoTokenExists, data_type);
                       }
                       std::string word = py::reinterpret_borrow<py::str>(py_word);
                       WordIdType default_id = vocab->Lookup(word);
@@ -134,7 +135,7 @@ PYBIND_REGISTER(LookupOp, 1, ([](const py::module *m) {
                         THROW_IF_ERROR(Status(StatusCode::kUnexpectedError,
                                               "default unknown token: " + word + " doesn't exist in vocab."));
                       }
-                      return std::make_shared<LookupOp>(vocab, default_id);
+                      return std::make_shared<LookupOp>(vocab, default_id, data_type);
                     }));
                 }));
......
@@ -22,8 +22,9 @@ namespace dataset {
 namespace api {
 namespace text {
 
-std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token) {
-  auto op = std::make_shared<LookupOperation>(vocab, unknown_token);
+std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
+                                        const DataType &data_type) {
+  auto op = std::make_shared<LookupOperation>(vocab, unknown_token, data_type);
   if (!op->ValidateParams()) {
     return nullptr;
@@ -32,8 +33,9 @@ std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, con
 }
 
 // LookupOperation
-LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token)
-    : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists) {}
+LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
+                                 const DataType &data_type)
+    : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists), data_type_(data_type) {}
 
 bool LookupOperation::ValidateParams() {
   if (vocab_ == nullptr) {
@@ -54,7 +56,7 @@ bool LookupOperation::ValidateParams() {
 }
 
 std::shared_ptr<TensorOp> LookupOperation::Build() {
-  std::shared_ptr<LookupOp> tensor_op = std::make_shared<LookupOp>(vocab_, default_id_);
+  std::shared_ptr<LookupOp> tensor_op = std::make_shared<LookupOp>(vocab_, default_id_, data_type_);
   return tensor_op;
 }
......
@@ -20,9 +20,11 @@
 #include <vector>
 #include <memory>
 #include <string>
 
 #include "minddata/dataset/core/constants.h"
 #include "minddata/dataset/include/transforms.h"
 #include "minddata/dataset/text/vocab.h"
+#include "mindspore/ccsrc/minddata/dataset/core/data_type.h"
 
 namespace mindspore {
 namespace dataset {
@@ -37,15 +39,18 @@ class LookupOperation;
 /// \brief Lookup operator that looks up a word to an id.
 /// \param[in] vocab a Vocab object.
 /// \param[in] unknown_token word to use for lookup if the word being looked up is out of Vocabulary (oov).
-///    If unknown_token is oov, runtime error will be thrown
+///    If unknown_token is oov, runtime error will be thrown.
+/// \param[in] data_type type of the tensor after lookup, typically int32.
 /// \return Shared pointer to the current TensorOperation.
-std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token);
+std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
+                                        const mindspore::dataset::DataType &data_type = DataType("int32"));
 
 /* ####################################### Derived TensorOperation classes ################################# */
 
 class LookupOperation : public TensorOperation {
  public:
-  explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token);
+  explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
+                           const DataType &data_type);
 
   ~LookupOperation() = default;
 
@@ -57,6 +62,7 @@ class LookupOperation : public TensorOperation {
   std::shared_ptr<Vocab> vocab_;
   std::string unknown_token_;
   int32_t default_id_;
+  DataType data_type_;
 };
 
 }  // namespace text
 }  // namespace api
......
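For orientation, here is a minimal sketch of how the extended C++ API could be used end to end. It mirrors the TestVocabFromDatasetInt64 case added further below; the file path is a placeholder, and the snippet assumes the includes and using-declarations from that test file.

```cpp
// Sketch only: relies on the headers and usings from the pipeline test below
// (using namespace mindspore::dataset::api; using mindspore::dataset::DataType; ...).
std::shared_ptr<Dataset> ds = TextFile({"/path/to/testVocab/words.txt"}, 0, ShuffleMode::kFalse);
std::shared_ptr<Vocab> vocab = ds->BuildVocab();

// Out-of-vocabulary words fall back to the id of "home"; with the new argument the
// lookup result is emitted as int64 instead of the default int32.
std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "home", DataType("int64"));
ds = ds->Map({lookup});

std::shared_ptr<Iterator> iter = ds->CreateIterator();
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);  // row["text"] now holds int64 word ids
```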
@@ -13,15 +13,16 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include "minddata/dataset/text/kernels/lookup_op.h"
+
 #include <string>
 
-#include "minddata/dataset/text/kernels/lookup_op.h"
+#include "minddata/dataset/kernels/data/data_utils.h"
 
 namespace mindspore {
 namespace dataset {
 
-LookupOp::LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id)
-    : vocab_(vocab), default_id_(default_id), type_(DataType("int32")) {}
+LookupOp::LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id, const DataType &data_type)
+    : vocab_(vocab), default_id_(default_id), type_(data_type) {}
 
 Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
   IO_CHECK(input, output);
@@ -37,6 +38,14 @@ Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
       "Lookup Error: token: " + std::string(*itr) + " doesn't exist in vocab and no unknown token is specified.");
   }
   RETURN_IF_NOT_OK(Tensor::CreateFromVector(word_ids, input->shape(), output));
+  // type cast to user's requirements if what user wants isn't int32_t
+  if ((*output)->type() != type_) {
+    std::shared_ptr<Tensor> cast_to;
+    RETURN_IF_NOT_OK(TypeCast(*output, &cast_to, type_));
+    *output = cast_to;
+  }
+
   return Status::OK();
 }
 
 Status LookupOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
......
@@ -18,9 +18,9 @@
 #define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_LOOKUP_OP_H_
 
 #include <memory>
-#include <vector>
-#include <utility>
 #include <string>
+#include <utility>
+#include <vector>
 
 #include "minddata/dataset/core/tensor.h"
 #include "minddata/dataset/kernels/tensor_op.h"
@@ -31,26 +31,27 @@ namespace mindspore {
 namespace dataset {
 
 class LookupOp : public TensorOp {
  public:
-  // constructor for lookup, takes in a vocab object
-  // @param std::shared_ptr<Vocab> vocab -
-  // @param WordIdType default_id, id to lookup if a word is not in vocab
-  explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id = 1);
+  /// \brief constructor for lookup, takes in a vocab object.
+  /// \param[in] std::shared_ptr<Vocab> vocab - vocab used for lookup.
+  /// \param[in] WordIdType default_id, id to lookup if a word is not in vocab.
+  /// \param[in] DataType type of the tensor after lookup, mostly int32.
+  explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id, const DataType &data_type);
 
   ~LookupOp() = default;
 
-  // perform actual lookup on each tensor
-  // @param const std::shared_ptr<Tensor> &input
-  // @param std::shared_ptr<Tensor> *output
-  // @return error code
+  /// \brief perform actual lookup on each tensor.
+  /// \param[in] const std::shared_ptr<Tensor> &input
+  /// \param[in] std::shared_ptr<Tensor> *output
+  /// \return[out] error code.
   Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
 
-  // print method
-  // @param std::ostream out
+  /// \brief print method.
+  /// \param[in] std::ostream out
   void Print(std::ostream &out) const override;
 
-  // @param std::vector<DataType> &inputs -
-  // @param std::vector<DataType> &outputs -
-  // @return error code
+  /// \param[in] std::vector<DataType> &inputs -
+  /// \param[in] std::vector<DataType> &outputs -
+  /// \return[out] error code.
   Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
 
   std::string Name() const override { return kLookupOp; }
......
@@ -49,6 +49,7 @@ import platform
 import numpy as np
 
 import mindspore._c_dataengine as cde
+import mindspore.common.dtype as mstype
 
 from .utils import JiebaMode, NormalizeForm, to_str, SPieceTokenizerOutType, SPieceTokenizerLoadType
 from .validators import check_lookup, check_jieba_add_dict, \
@@ -66,11 +67,12 @@ class Lookup(cde.LookupOp):
         vocab(Vocab): a Vocab object.
         unknown_token(str, optional): word to use for lookup if the word being looked up is out of Vocabulary (oov).
             If unknown_token is oov, runtime error will be thrown (default=None).
+        data_type (mindspore.dtype, optional): mindspore.dtype lookup maps string to (default=mstype.int32)
     """
 
     @check_lookup
-    def __init__(self, vocab, unknown_token=None):
-        super().__init__(vocab, unknown_token)
+    def __init__(self, vocab, unknown_token=None, data_type=mstype.int32):
+        super().__init__(vocab, unknown_token, mstype_to_detype(data_type))
 
 
 class SlidingWindow(cde.SlidingWindowOp):
@@ -103,7 +105,6 @@ class SlidingWindow(cde.SlidingWindowOp):
         super().__init__(width, axis)
 
-
 class Ngram(cde.NgramOp):
     """
     TensorOp to generate n-gram from a 1-D string Tensor.
......
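From the Python side, here is a minimal usage sketch of the new data_type argument, modeled on the test_lookup_cast_type test added below; the vocab contents and the "text" column name are made up for the example.

```python
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as text
import mindspore.common.dtype as mstype

# Small vocab; "<unk>" is prepended as the out-of-vocabulary token.
vocab = text.Vocab.from_list(["w1", "w2", "w3"], special_tokens=["<unk>"], special_first=True)

def gen():
    for word in "w1 w2 w3 oov".split(" "):
        yield (np.array(word, dtype='S'),)

data = ds.GeneratorDataset(gen(), column_names=["text"])
# Map each word to its id; emit int64 tensors instead of the default int32.
data = data.map(input_columns=["text"], operations=text.Lookup(vocab, "<unk>", mstype.int64))

for d in data.create_dict_iterator(num_epochs=1):
    print(d["text"].dtype)  # int64
```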
@@ -44,12 +44,13 @@ def check_lookup(method):
     @wraps(method)
     def new_method(self, *args, **kwargs):
-        [vocab, unknown_token], _ = parse_user_args(method, *args, **kwargs)
+        [vocab, unknown_token, data_type], _ = parse_user_args(method, *args, **kwargs)
         if unknown_token is not None:
             type_check(unknown_token, (str,), "unknown_token")
 
         type_check(vocab, (cde.Vocab,), "vocab is not an instance of cde.Vocab.")
+        type_check(data_type, (typing.Type,), "data_type")
 
         return method(self, *args, **kwargs)
@@ -327,6 +328,7 @@ def check_from_dataset(method):
     return new_method
 
+
 def check_slidingwindow(method):
     """A wrapper that wraps a parameter checker to the original function(sliding window operation)."""
@@ -339,6 +341,7 @@ def check_slidingwindow(method):
     return new_method
 
+
 def check_ngram(method):
     """A wrapper that wraps a parameter checker to the original function."""
......
@@ -26,9 +26,10 @@
 #include "minddata/dataset/include/text.h"
 
 using namespace mindspore::dataset::api;
+using mindspore::dataset::DataType;
 using mindspore::dataset::ShuffleMode;
-using mindspore::dataset::Tensor;
 using mindspore::dataset::Status;
+using mindspore::dataset::Tensor;
 using mindspore::dataset::Vocab;
 
 class MindDataTestPipeline : public UT::DatasetOpTesting {
@@ -50,7 +51,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOp) {
   EXPECT_EQ(s, Status::OK());
 
   // Create Lookup operation on ds
-  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>");
+  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>", DataType("int32"));
   EXPECT_NE(lookup, nullptr);
 
   // Create Map operation on ds
@@ -94,7 +95,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpFail1) {
   // Create lookup op for ds
   // Expected failure: "<unk>" is not a word of vocab
-  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>");
+  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>", DataType("int32"));
   EXPECT_EQ(lookup, nullptr);
 }
@@ -105,7 +106,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpFail2) {
   // Create lookup op
   // Expected failure: vocab is null
-  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "");
+  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "", DataType("int32"));
   EXPECT_EQ(lookup, nullptr);
 }
@@ -126,7 +127,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpWithEmptyUnknownToken) {
   // Create Lookup operation on ds
   // Expected failure: "" is not a word of vocab
-  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "");
+  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "", DataType("int32"));
   EXPECT_EQ(lookup, nullptr);
 }
@@ -148,7 +149,7 @@ TEST_F(MindDataTestPipeline, TestVocabFromDataset) {
   EXPECT_EQ(home_index, 4);
 
   // Create Lookup operation on ds
-  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>");
+  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>", DataType("int32"));
   EXPECT_NE(lookup, nullptr);
 
   // Create Map operation on ds
@@ -212,12 +213,15 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetDefault) {
   uint64_t i = 0;
   std::vector<int32_t> expected = {2, 3, 1, 4, 5, 0};
+  std::vector<int64_t> not_expected = {2, 3, 1, 4, 5, 0};
   while (row.size() != 0) {
     auto ind = row["text"];
     MS_LOG(INFO) << ind->shape() << " " << *ind;
-    std::shared_ptr<Tensor> expected_item;
+    std::shared_ptr<Tensor> expected_item, not_expected_item;
     Tensor::CreateScalar(expected[i], &expected_item);
+    Tensor::CreateScalar(not_expected[i], &not_expected_item);
     EXPECT_EQ(*ind, *expected_item);
+    EXPECT_NE(*ind, *not_expected_item);
     iter->GetNextRow(&row);
     i++;
   }
@@ -233,8 +237,8 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail1) {
   // Create vocab from dataset
   // Expected failure: top_k can not be negative
-  std::shared_ptr<Vocab> vocab = ds->BuildVocab({"text"}, {0, std::numeric_limits<int64_t>::max()},
-                                                -2, {"<pad>", "<unk>"}, true);
+  std::shared_ptr<Vocab> vocab =
+    ds->BuildVocab({"text"}, {0, std::numeric_limits<int64_t>::max()}, -2, {"<pad>", "<unk>"}, true);
   EXPECT_EQ(vocab, nullptr);
 }
@@ -247,9 +251,9 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail2) {
   EXPECT_NE(ds, nullptr);
 
   // Create vocab from dataset
-  // Expected failure: requency_range [a,b] should be 0 <= a <= b
-  std::shared_ptr<Vocab> vocab = ds->BuildVocab({"text"}, {4, 1},
-                                                std::numeric_limits<int64_t>::max(), {"<pad>", "<unk>"}, true);
+  // Expected failure: frequency_range [a,b] should be 0 <= a <= b
+  std::shared_ptr<Vocab> vocab =
+    ds->BuildVocab({"text"}, {4, 1}, std::numeric_limits<int64_t>::max(), {"<pad>", "<unk>"}, true);
   EXPECT_EQ(vocab, nullptr);
 }
@@ -266,3 +270,52 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail3) {
   std::shared_ptr<Vocab> vocab = ds->BuildVocab({"ColumnNotExist"});
   EXPECT_EQ(vocab, nullptr);
 }
+
+TEST_F(MindDataTestPipeline, TestVocabFromDatasetInt64) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVocabFromDatasetInt64.";
+
+  // Create a TextFile dataset
+  std::string data_file = datasets_root_path_ + "/testVocab/words.txt";
+  std::shared_ptr<Dataset> ds = TextFile({data_file}, 0, ShuffleMode::kFalse);
+  EXPECT_NE(ds, nullptr);
+
+  // Create vocab from dataset
+  std::shared_ptr<Vocab> vocab = ds->BuildVocab();
+  EXPECT_NE(vocab, nullptr);
+
+  // Check if vocab has words or not
+  int32_t home_index = vocab->Lookup("home");
+  EXPECT_EQ(home_index, 2);
+
+  // Create Lookup operation on ds
+  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "home", DataType("int64"));
+  EXPECT_NE(lookup, nullptr);
+
+  // Create Map operation on ds
+  ds = ds->Map({lookup});
+  EXPECT_NE(ds, nullptr);
+
+  // Create an iterator over the result of the above dataset
+  // This will trigger the creation of the Execution Tree and launch it.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+
+  // Iterate the dataset and get each row
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+
+  uint64_t i = 0;
+  std::vector<int64_t> expected = {2, 3, 1, 4, 5, 0};
+  std::vector<int8_t> not_expected = {2, 3, 1, 4, 5, 0};
+  while (row.size() != 0) {
+    auto ind = row["text"];
+    MS_LOG(INFO) << ind->shape() << " " << *ind;
+    std::shared_ptr<Tensor> expected_item, not_expected_item;
+    Tensor::CreateScalar(expected[i], &expected_item);
+    Tensor::CreateScalar(not_expected[i], &not_expected_item);
+    EXPECT_EQ(*ind, *expected_item);
+    EXPECT_NE(*ind, *not_expected_item);
+    iter->GetNextRow(&row);
+    i++;
+  }
+}
\ No newline at end of file
@@ -17,6 +17,7 @@ import numpy as np
 import mindspore.dataset as ds
 import mindspore.dataset.text as text
+import mindspore.common.dtype as mstype
 
 # this file contains "home is behind the world head" each word is 1 line
 DATA_FILE = "../data/dataset/testVocab/words.txt"
@@ -137,6 +138,36 @@ def test_from_file():
     assert "Input vocab_size must be greater than 0" in test_config("w1 w2", 0, [], True)
     assert "Input vocab_size must be greater than 0" in test_config("w1 w2", -1, [], True)
 
+
+def test_lookup_cast_type():
+    def gen(texts):
+        for word in texts.split(" "):
+            yield (np.array(word, dtype='S'),)
+
+    def test_config(lookup_str, data_type=None):
+        try:
+            vocab = text.Vocab.from_list(["w1", "w2", "w3"], special_tokens=["<unk>"], special_first=True)
+            data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"])
+            # if data_type is None, test the default value of data_type
+            op = text.Lookup(vocab, "<unk>") if data_type is None else text.Lookup(vocab, "<unk>", data_type)
+            data = data.map(input_columns=["text"], operations=op)
+            res = []
+            for d in data.create_dict_iterator(num_epochs=1):
+                res.append(d["text"])
+            return res[0].dtype
+        except (ValueError, RuntimeError, TypeError) as e:
+            return str(e)
+
+    # test result is correct
+    assert test_config("w1", mstype.int8) == np.dtype("int8")
+    assert test_config("w2", mstype.int32) == np.dtype("int32")
+    assert test_config("w3", mstype.int64) == np.dtype("int64")
+    assert test_config("unk", mstype.float32) != np.dtype("int32")
+    assert test_config("unk") == np.dtype("int32")
+    # test exception, data_type isn't the correct type
+    assert "tldr is not of type (<class 'mindspore._c_expression.typing.Type'>,)" in test_config("unk", "tldr")
+
+
 if __name__ == '__main__':
     test_from_dict_exception()
     test_from_list_tutorial()
@@ -144,3 +175,4 @@ if __name__ == '__main__':
     test_from_dict_tutorial()
     test_from_list()
     test_from_file()
+    test_lookup_cast_type()