Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindspore
提交
bde9f18f
M
mindspore
项目概览
MindSpore
/
mindspore
通知
35
Star
15
Fork
15
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
bde9f18f
编写于
9月 04, 2020
作者:
Z
Zirui Wu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update lookup api to take in a type
ci add test case address some of the review cmts address review cmts
上级
75045e3e
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
154 addition
and
46 deletion
+154
-46
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/text/kernels/bindings.cc
...aset/api/python/bindings/dataset/text/kernels/bindings.cc
+4
-3
mindspore/ccsrc/minddata/dataset/api/text.cc
mindspore/ccsrc/minddata/dataset/api/text.cc
+7
-5
mindspore/ccsrc/minddata/dataset/include/text.h
mindspore/ccsrc/minddata/dataset/include/text.h
+9
-3
mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc
mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc
+13
-4
mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.h
mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.h
+16
-15
mindspore/dataset/text/transforms.py
mindspore/dataset/text/transforms.py
+4
-3
mindspore/dataset/text/validators.py
mindspore/dataset/text/validators.py
+4
-1
tests/ut/cpp/dataset/c_api_dataset_vocab.cc
tests/ut/cpp/dataset/c_api_dataset_vocab.cc
+65
-12
tests/ut/python/dataset/test_vocab.py
tests/ut/python/dataset/test_vocab.py
+32
-0
未找到文件。
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/text/kernels/bindings.cc
浏览文件 @
bde9f18f
...
...
@@ -121,12 +121,13 @@ PYBIND_REGISTER(UnicodeCharTokenizerOp, 1, ([](const py::module *m) {
PYBIND_REGISTER
(
LookupOp
,
1
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
class_
<
LookupOp
,
TensorOp
,
std
::
shared_ptr
<
LookupOp
>>
(
*
m
,
"LookupOp"
)
.
def
(
py
::
init
([](
std
::
shared_ptr
<
Vocab
>
vocab
,
const
py
::
object
&
py_word
)
{
.
def
(
py
::
init
([](
std
::
shared_ptr
<
Vocab
>
vocab
,
const
py
::
object
&
py_word
,
const
DataType
&
data_type
)
{
if
(
vocab
==
nullptr
)
{
THROW_IF_ERROR
(
Status
(
StatusCode
::
kUnexpectedError
,
"vocab object type is incorrect or null."
));
}
if
(
py_word
.
is_none
())
{
return
std
::
make_shared
<
LookupOp
>
(
vocab
,
Vocab
::
kNoTokenExists
);
return
std
::
make_shared
<
LookupOp
>
(
vocab
,
Vocab
::
kNoTokenExists
,
data_type
);
}
std
::
string
word
=
py
::
reinterpret_borrow
<
py
::
str
>
(
py_word
);
WordIdType
default_id
=
vocab
->
Lookup
(
word
);
...
...
@@ -134,7 +135,7 @@ PYBIND_REGISTER(LookupOp, 1, ([](const py::module *m) {
THROW_IF_ERROR
(
Status
(
StatusCode
::
kUnexpectedError
,
"default unknown token: "
+
word
+
" doesn't exist in vocab."
));
}
return
std
::
make_shared
<
LookupOp
>
(
vocab
,
default_id
);
return
std
::
make_shared
<
LookupOp
>
(
vocab
,
default_id
,
data_type
);
}));
}));
...
...
mindspore/ccsrc/minddata/dataset/api/text.cc
浏览文件 @
bde9f18f
...
...
@@ -22,8 +22,9 @@ namespace dataset {
namespace
api
{
namespace
text
{
std
::
shared_ptr
<
LookupOperation
>
Lookup
(
const
std
::
shared_ptr
<
Vocab
>
&
vocab
,
const
std
::
string
&
unknown_token
)
{
auto
op
=
std
::
make_shared
<
LookupOperation
>
(
vocab
,
unknown_token
);
std
::
shared_ptr
<
LookupOperation
>
Lookup
(
const
std
::
shared_ptr
<
Vocab
>
&
vocab
,
const
std
::
string
&
unknown_token
,
const
DataType
&
data_type
)
{
auto
op
=
std
::
make_shared
<
LookupOperation
>
(
vocab
,
unknown_token
,
data_type
);
if
(
!
op
->
ValidateParams
())
{
return
nullptr
;
...
...
@@ -32,8 +33,9 @@ std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, con
}
// LookupOperation
LookupOperation
::
LookupOperation
(
const
std
::
shared_ptr
<
Vocab
>
&
vocab
,
const
std
::
string
&
unknown_token
)
:
vocab_
(
vocab
),
unknown_token_
(
unknown_token
),
default_id_
(
Vocab
::
kNoTokenExists
)
{}
/// \brief Build a LookupOperation from a vocab, an unknown-token string and the requested output type.
/// \param[in] vocab Vocab object stored in vocab_.
/// \param[in] unknown_token word stored in unknown_token_.
/// \param[in] data_type data type stored in data_type_ and later handed to LookupOp by Build().
/// \note default_id_ starts out as Vocab::kNoTokenExists.
///       NOTE(review): presumably it is resolved from unknown_token_ during validation - confirm, that code is
///       not visible in this hunk.
LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
                                 const DataType &data_type)
    : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists), data_type_(data_type) {}
bool
LookupOperation
::
ValidateParams
()
{
if
(
vocab_
==
nullptr
)
{
...
...
@@ -54,7 +56,7 @@ bool LookupOperation::ValidateParams() {
}
std
::
shared_ptr
<
TensorOp
>
LookupOperation
::
Build
()
{
std
::
shared_ptr
<
LookupOp
>
tensor_op
=
std
::
make_shared
<
LookupOp
>
(
vocab_
,
default_id_
);
std
::
shared_ptr
<
LookupOp
>
tensor_op
=
std
::
make_shared
<
LookupOp
>
(
vocab_
,
default_id_
,
data_type_
);
return
tensor_op
;
}
...
...
mindspore/ccsrc/minddata/dataset/include/text.h
浏览文件 @
bde9f18f
...
...
@@ -20,9 +20,11 @@
#include <vector>
#include <memory>
#include <string>
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/text/vocab.h"
#include "mindspore/ccsrc/minddata/dataset/core/data_type.h"
namespace
mindspore
{
namespace
dataset
{
...
...
@@ -37,15 +39,18 @@ class LookupOperation;
/// \brief Lookup operator that looks up a word to an id.
/// \param[in] vocab a Vocab object.
/// \param[in] unknown_token word to use for lookup if the word being looked up is out of Vocabulary (oov).
/// If unknown_token is oov, runtime error will be thrown
/// If unknown_token is oov, runtime error will be thrown.
/// \param[in] data_type type of the tensor after lookup, typically int32.
/// \return Shared pointer to the current TensorOperation.
std
::
shared_ptr
<
LookupOperation
>
Lookup
(
const
std
::
shared_ptr
<
Vocab
>
&
vocab
,
const
std
::
string
&
unknown_token
);
std
::
shared_ptr
<
LookupOperation
>
Lookup
(
const
std
::
shared_ptr
<
Vocab
>
&
vocab
,
const
std
::
string
&
unknown_token
,
const
mindspore
::
dataset
::
DataType
&
data_type
=
DataType
(
"int32"
));
/* ####################################### Derived TensorOperation classes ################################# */
class
LookupOperation
:
public
TensorOperation
{
public:
explicit
LookupOperation
(
const
std
::
shared_ptr
<
Vocab
>
&
vocab
,
const
std
::
string
&
unknown_token
);
explicit
LookupOperation
(
const
std
::
shared_ptr
<
Vocab
>
&
vocab
,
const
std
::
string
&
unknown_token
,
const
DataType
&
data_type
);
~
LookupOperation
()
=
default
;
...
...
@@ -57,6 +62,7 @@ class LookupOperation : public TensorOperation {
std
::
shared_ptr
<
Vocab
>
vocab_
;
std
::
string
unknown_token_
;
int32_t
default_id_
;
DataType
data_type_
;
};
}
// namespace text
}
// namespace api
...
...
mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc
浏览文件 @
bde9f18f
...
...
@@ -13,15 +13,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/text/kernels/lookup_op.h"
#include <string>
#include "minddata/dataset/kernels/data/data_utils.h"
#include "minddata/dataset/text/kernels/lookup_op.h"
namespace
mindspore
{
namespace
dataset
{
LookupOp
::
LookupOp
(
std
::
shared_ptr
<
Vocab
>
vocab
,
WordIdType
default_id
)
:
vocab_
(
vocab
),
default_id_
(
default_id
),
type_
(
DataType
(
"int32"
)
)
{}
/// \brief Construct a LookupOp.
/// \param[in] vocab vocab used for lookup; taken by value and moved into vocab_ so no extra
///     shared_ptr copy (atomic ref-count increment/decrement) is made.
/// \param[in] default_id id to lookup if a word is not in vocab.
/// \param[in] data_type type of the tensor after lookup; Compute() casts its result to this type when it differs.
LookupOp::LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id, const DataType &data_type)
    : vocab_(std::move(vocab)), default_id_(default_id), type_(data_type) {}
Status
LookupOp
::
Compute
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
std
::
shared_ptr
<
Tensor
>
*
output
)
{
IO_CHECK
(
input
,
output
);
...
...
@@ -37,6 +38,14 @@ Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
"Lookup Error: token: "
+
std
::
string
(
*
itr
)
+
" doesn't exist in vocab and no unknown token is specified."
);
}
RETURN_IF_NOT_OK
(
Tensor
::
CreateFromVector
(
word_ids
,
input
->
shape
(),
output
));
// type cast to user's requirements if what user wants isn't int32_t
if
((
*
output
)
->
type
()
!=
type_
)
{
std
::
shared_ptr
<
Tensor
>
cast_to
;
RETURN_IF_NOT_OK
(
TypeCast
(
*
output
,
&
cast_to
,
type_
));
*
output
=
cast_to
;
}
return
Status
::
OK
();
}
Status
LookupOp
::
OutputType
(
const
std
::
vector
<
DataType
>
&
inputs
,
std
::
vector
<
DataType
>
&
outputs
)
{
...
...
mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.h
浏览文件 @
bde9f18f
...
...
@@ -18,9 +18,9 @@
#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_LOOKUP_OP_H_
#include <memory>
#include <vector>
#include <utility>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
...
...
@@ -31,26 +31,27 @@ namespace mindspore {
namespace
dataset
{
class
LookupOp
:
public
TensorOp
{
public:
// constructor for lookup, takes in a vocab object
// @param std::shared_ptr<Vocab> vocab -
// @param WordIdType default_id, id to lookup if a word is not in vocab
explicit
LookupOp
(
std
::
shared_ptr
<
Vocab
>
vocab
,
WordIdType
default_id
=
1
);
/// \brief constructor for lookup, takes in a vocab object.
/// \param[in] vocab vocab used for lookup (std::shared_ptr<Vocab>).
/// \param[in] default_id id to lookup if a word is not in vocab (WordIdType).
/// \param[in] data_type type of the tensor after lookup, mostly int32.
explicit
LookupOp
(
std
::
shared_ptr
<
Vocab
>
vocab
,
WordIdType
default_id
,
const
DataType
&
data_type
);
~
LookupOp
()
=
default
;
//
perform actual lookup on each tensor
//
@param
const std::shared_ptr<Tensor> &input
//
@param
std::shared_ptr<Tensor> *output
//
@return error code
/// \brief perform actual lookup on each tensor.
/// \param[in] input tensor holding the tokens to look up (const std::shared_ptr<Tensor> &).
/// \param[out] output tensor receiving the looked-up ids (std::shared_ptr<Tensor> *).
/// \return error code.
Status
Compute
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
std
::
shared_ptr
<
Tensor
>
*
output
)
override
;
//
print method
//
@param
std::ostream out
//
/ \brief print method.
//
/ \param[in]
std::ostream out
void
Print
(
std
::
ostream
&
out
)
const
override
;
//
@param
std::vector<DataType> &inputs -
//
@param
std::vector<DataType> &outputs -
//
@return error code
//
/ \param[in]
std::vector<DataType> &inputs -
//
/ \param[in]
std::vector<DataType> &outputs -
//
/ \return[out] error code.
Status
OutputType
(
const
std
::
vector
<
DataType
>
&
inputs
,
std
::
vector
<
DataType
>
&
outputs
)
override
;
std
::
string
Name
()
const
override
{
return
kLookupOp
;
}
...
...
mindspore/dataset/text/transforms.py
浏览文件 @
bde9f18f
...
...
@@ -49,6 +49,7 @@ import platform
import
numpy
as
np
import
mindspore._c_dataengine
as
cde
import
mindspore.common.dtype
as
mstype
from
.utils
import
JiebaMode
,
NormalizeForm
,
to_str
,
SPieceTokenizerOutType
,
SPieceTokenizerLoadType
from
.validators
import
check_lookup
,
check_jieba_add_dict
,
\
...
...
@@ -66,11 +67,12 @@ class Lookup(cde.LookupOp):
vocab(Vocab): a Vocab object.
unknown_token(str, optional): word to use for lookup if the word being looked up is out of Vocabulary (oov).
If unknown_token is oov, runtime error will be thrown (default=None).
data_type (mindspore.dtype, optional): mindspore.dtype lookup maps string to (default=mstype.int32)
"""
@
check_lookup
def
__init__
(
self
,
vocab
,
unknown_token
=
None
):
super
().
__init__
(
vocab
,
unknown_token
)
def
__init__
(
self
,
vocab
,
unknown_token
=
None
,
data_type
=
mstype
.
int32
):
super
().
__init__
(
vocab
,
unknown_token
,
mstype_to_detype
(
data_type
)
)
class
SlidingWindow
(
cde
.
SlidingWindowOp
):
...
...
@@ -103,7 +105,6 @@ class SlidingWindow(cde.SlidingWindowOp):
super
().
__init__
(
width
,
axis
)
class
Ngram
(
cde
.
NgramOp
):
"""
TensorOp to generate n-gram from a 1-D string Tensor.
...
...
mindspore/dataset/text/validators.py
浏览文件 @
bde9f18f
...
...
@@ -44,12 +44,13 @@ def check_lookup(method):
@
wraps
(
method
)
def
new_method
(
self
,
*
args
,
**
kwargs
):
[
vocab
,
unknown_token
],
_
=
parse_user_args
(
method
,
*
args
,
**
kwargs
)
[
vocab
,
unknown_token
,
data_type
],
_
=
parse_user_args
(
method
,
*
args
,
**
kwargs
)
if
unknown_token
is
not
None
:
type_check
(
unknown_token
,
(
str
,),
"unknown_token"
)
type_check
(
vocab
,
(
cde
.
Vocab
,),
"vocab is not an instance of cde.Vocab."
)
type_check
(
data_type
,
(
typing
.
Type
,),
"data_type"
)
return
method
(
self
,
*
args
,
**
kwargs
)
...
...
@@ -327,6 +328,7 @@ def check_from_dataset(method):
return
new_method
def
check_slidingwindow
(
method
):
"""A wrapper that wraps a parameter checker to the original function(sliding window operation)."""
...
...
@@ -339,6 +341,7 @@ def check_slidingwindow(method):
return
new_method
def
check_ngram
(
method
):
"""A wrapper that wraps a parameter checker to the original function."""
...
...
tests/ut/cpp/dataset/c_api_dataset_vocab.cc
浏览文件 @
bde9f18f
...
...
@@ -26,9 +26,10 @@
#include "minddata/dataset/include/text.h"
using
namespace
mindspore
::
dataset
::
api
;
using
mindspore
::
dataset
::
DataType
;
using
mindspore
::
dataset
::
ShuffleMode
;
using
mindspore
::
dataset
::
Tensor
;
using
mindspore
::
dataset
::
Status
;
using
mindspore
::
dataset
::
Tensor
;
using
mindspore
::
dataset
::
Vocab
;
class
MindDataTestPipeline
:
public
UT
::
DatasetOpTesting
{
...
...
@@ -50,7 +51,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOp) {
EXPECT_EQ
(
s
,
Status
::
OK
());
// Create Lookup operation on ds
std
::
shared_ptr
<
TensorOperation
>
lookup
=
text
::
Lookup
(
vocab
,
"<unk>"
);
std
::
shared_ptr
<
TensorOperation
>
lookup
=
text
::
Lookup
(
vocab
,
"<unk>"
,
DataType
(
"int32"
)
);
EXPECT_NE
(
lookup
,
nullptr
);
// Create Map operation on ds
...
...
@@ -94,7 +95,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpFail1) {
// Create lookup op for ds
// Expected failure: "<unk>" is not a word of vocab
std
::
shared_ptr
<
TensorOperation
>
lookup
=
text
::
Lookup
(
vocab
,
"<unk>"
);
std
::
shared_ptr
<
TensorOperation
>
lookup
=
text
::
Lookup
(
vocab
,
"<unk>"
,
DataType
(
"int32"
)
);
EXPECT_EQ
(
lookup
,
nullptr
);
}
...
...
@@ -105,7 +106,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpFail2) {
// Create lookup op
// Expected failure: vocab is null
std
::
shared_ptr
<
TensorOperation
>
lookup
=
text
::
Lookup
(
vocab
,
""
);
std
::
shared_ptr
<
TensorOperation
>
lookup
=
text
::
Lookup
(
vocab
,
""
,
DataType
(
"int32"
)
);
EXPECT_EQ
(
lookup
,
nullptr
);
}
...
...
@@ -126,7 +127,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpWithEmptyUnknownToken) {
// Create Lookup operation on ds
// Expected failure: "" is not a word of vocab
std
::
shared_ptr
<
TensorOperation
>
lookup
=
text
::
Lookup
(
vocab
,
""
);
std
::
shared_ptr
<
TensorOperation
>
lookup
=
text
::
Lookup
(
vocab
,
""
,
DataType
(
"int32"
)
);
EXPECT_EQ
(
lookup
,
nullptr
);
}
...
...
@@ -148,7 +149,7 @@ TEST_F(MindDataTestPipeline, TestVocabFromDataset) {
EXPECT_EQ
(
home_index
,
4
);
// Create Lookup operation on ds
std
::
shared_ptr
<
TensorOperation
>
lookup
=
text
::
Lookup
(
vocab
,
"<unk>"
);
std
::
shared_ptr
<
TensorOperation
>
lookup
=
text
::
Lookup
(
vocab
,
"<unk>"
,
DataType
(
"int32"
)
);
EXPECT_NE
(
lookup
,
nullptr
);
// Create Map operation on ds
...
...
@@ -212,12 +213,15 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetDefault) {
uint64_t
i
=
0
;
std
::
vector
<
int32_t
>
expected
=
{
2
,
3
,
1
,
4
,
5
,
0
};
std
::
vector
<
int64_t
>
not_expected
=
{
2
,
3
,
1
,
4
,
5
,
0
};
while
(
row
.
size
()
!=
0
)
{
auto
ind
=
row
[
"text"
];
MS_LOG
(
INFO
)
<<
ind
->
shape
()
<<
" "
<<
*
ind
;
std
::
shared_ptr
<
Tensor
>
expected_item
;
std
::
shared_ptr
<
Tensor
>
expected_item
,
not_expected_item
;
Tensor
::
CreateScalar
(
expected
[
i
],
&
expected_item
);
Tensor
::
CreateScalar
(
not_expected
[
i
],
&
not_expected_item
);
EXPECT_EQ
(
*
ind
,
*
expected_item
);
EXPECT_NE
(
*
ind
,
*
not_expected_item
);
iter
->
GetNextRow
(
&
row
);
i
++
;
}
...
...
@@ -233,8 +237,8 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail1) {
// Create vocab from dataset
// Expected failure: top_k can not be negative
std
::
shared_ptr
<
Vocab
>
vocab
=
ds
->
BuildVocab
({
"text"
},
{
0
,
std
::
numeric_limits
<
int64_t
>::
max
()},
-
2
,
{
"<pad>"
,
"<unk>"
},
true
);
std
::
shared_ptr
<
Vocab
>
vocab
=
ds
->
BuildVocab
({
"text"
},
{
0
,
std
::
numeric_limits
<
int64_t
>::
max
()},
-
2
,
{
"<pad>"
,
"<unk>"
},
true
);
EXPECT_EQ
(
vocab
,
nullptr
);
}
...
...
@@ -247,9 +251,9 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail2) {
EXPECT_NE
(
ds
,
nullptr
);
// Create vocab from dataset
// Expected failure: requency_range [a,b] should be 0 <= a <= b
std
::
shared_ptr
<
Vocab
>
vocab
=
ds
->
BuildVocab
({
"text"
},
{
4
,
1
},
std
::
numeric_limits
<
int64_t
>::
max
(),
{
"<pad>"
,
"<unk>"
},
true
);
// Expected failure: frequency_range [a,b] should be 0 <= a <= b
std
::
shared_ptr
<
Vocab
>
vocab
=
ds
->
BuildVocab
({
"text"
},
{
4
,
1
},
std
::
numeric_limits
<
int64_t
>::
max
(),
{
"<pad>"
,
"<unk>"
},
true
);
EXPECT_EQ
(
vocab
,
nullptr
);
}
...
...
@@ -266,3 +270,52 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail3) {
std
::
shared_ptr
<
Vocab
>
vocab
=
ds
->
BuildVocab
({
"ColumnNotExist"
});
EXPECT_EQ
(
vocab
,
nullptr
);
}
/// \brief Build a vocab from a text file, run Lookup with an int64 output type and check every row.
/// Each looked-up tensor must equal the expected id as an int64 scalar and must differ from the
/// same id stored as an int8 scalar.
TEST_F(MindDataTestPipeline, TestVocabFromDatasetInt64) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVocabFromDatasetInt64.";

  // Create a TextFile dataset
  std::string data_file = datasets_root_path_ + "/testVocab/words.txt";
  std::shared_ptr<Dataset> ds = TextFile({data_file}, 0, ShuffleMode::kFalse);
  // Fatal assertion: ds is dereferenced right below, a null pointer must stop the test instead of crashing it.
  ASSERT_NE(ds, nullptr);

  // Create vocab from dataset
  std::shared_ptr<Vocab> vocab = ds->BuildVocab();
  ASSERT_NE(vocab, nullptr);

  // Check if vocab has words or not
  int32_t home_index = vocab->Lookup("home");
  EXPECT_EQ(home_index, 2);

  // Create Lookup operation on ds, asking for int64 output
  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "home", DataType("int64"));
  ASSERT_NE(lookup, nullptr);

  // Create Map operation on ds
  ds = ds->Map({lookup});
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  uint64_t i = 0;
  std::vector<int64_t> expected = {2, 3, 1, 4, 5, 0};
  std::vector<int8_t> not_expected = {2, 3, 1, 4, 5, 0};
  while (row.size() != 0) {
    // Guard the index: more rows than expected values must fail the test, not read out of bounds.
    ASSERT_LT(i, expected.size());
    auto ind = row["text"];
    // operator[] inserts a null pointer when the column is missing; stop before dereferencing it.
    ASSERT_NE(ind, nullptr);
    MS_LOG(INFO) << ind->shape() << " " << *ind;
    std::shared_ptr<Tensor> expected_item, not_expected_item;
    Tensor::CreateScalar(expected[i], &expected_item);
    Tensor::CreateScalar(not_expected[i], &not_expected_item);
    ASSERT_NE(expected_item, nullptr);
    ASSERT_NE(not_expected_item, nullptr);
    EXPECT_EQ(*ind, *expected_item);
    EXPECT_NE(*ind, *not_expected_item);
    iter->GetNextRow(&row);
    i++;
  }
}
\ No newline at end of file
tests/ut/python/dataset/test_vocab.py
浏览文件 @
bde9f18f
...
...
@@ -17,6 +17,7 @@ import numpy as np
import
mindspore.dataset
as
ds
import
mindspore.dataset.text
as
text
import
mindspore.common.dtype
as
mstype
# this file contains "home is behind the world head" each word is 1 line
DATA_FILE
=
"../data/dataset/testVocab/words.txt"
...
...
@@ -137,6 +138,36 @@ def test_from_file():
assert
"Input vocab_size must be greater than 0"
in
test_config
(
"w1 w2"
,
0
,
[],
True
)
assert
"Input vocab_size must be greater than 0"
in
test_config
(
"w1 w2"
,
-
1
,
[],
True
)
def test_lookup_cast_type():
    """Check the dtype produced by text.Lookup for several data_type arguments.

    Covers explicit mstype values, the default (no data_type passed), and a
    data_type of the wrong Python type, which is expected to surface as an
    error message.
    """

    def word_generator(sentence):
        # One row per space-separated word, each as a 0-d bytes array.
        for token in sentence.split(" "):
            yield (np.array(token, dtype='S'),)

    def lookup_dtype(lookup_str, data_type=None):
        # Returns the numpy dtype of the first looked-up row, or the error text
        # when building/running the pipeline raises one of the caught exceptions.
        try:
            vocab = text.Vocab.from_list(["w1", "w2", "w3"], special_tokens=["<unk>"], special_first=True)
            data = ds.GeneratorDataset(word_generator(lookup_str), column_names=["text"])
            # if data_type is None, test the default value of data_type
            if data_type is None:
                op = text.Lookup(vocab, "<unk>")
            else:
                op = text.Lookup(vocab, "<unk>", data_type)
            data = data.map(input_columns=["text"], operations=op)
            rows = [item["text"] for item in data.create_dict_iterator(num_epochs=1)]
            return rows[0].dtype
        except (ValueError, RuntimeError, TypeError) as err:
            return str(err)

    # test result is correct
    assert lookup_dtype("w1", mstype.int8) == np.dtype("int8")
    assert lookup_dtype("w2", mstype.int32) == np.dtype("int32")
    assert lookup_dtype("w3", mstype.int64) == np.dtype("int64")
    assert lookup_dtype("unk", mstype.float32) != np.dtype("int32")
    assert lookup_dtype("unk") == np.dtype("int32")
    # test exception, data_type isn't the correct type
    assert "tldr is not of type (<class 'mindspore._c_expression.typing.Type'>,)" in lookup_dtype("unk", "tldr")
if
__name__
==
'__main__'
:
test_from_dict_exception
()
test_from_list_tutorial
()
...
...
@@ -144,3 +175,4 @@ if __name__ == '__main__':
test_from_dict_tutorial
()
test_from_list
()
test_from_file
()
test_lookup_cast_type
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录