magicwindyyd / mindspore (forked from MindSpore / mindspore, in sync with the upstream project)
Commit 0a2980ca

Authored July 25, 2020 by mindspore-ci-bot; committed by Gitee on July 25, 2020

!3440 [MD] Fix segmentation fault when SentencePieceTokenizer op runs before ZipOp and ConcatOp

Merge pull request !3440 from xulei/sentence_piece0715

Parents: 0fac402a, 0af6d757
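The crash named in the title shows up when a pipeline that maps a SentencePieceTokenizer is then fed into a zip or concat. The sketch below mirrors the unit test added by this commit but is not part of it: the model and corpus paths are placeholders, and it assumes the 2020-era mindspore.dataset API, in which TextFileDataset yields a 'text' column and ds.zip requires its children to have disjoint column names (hence the rename on one branch).

```python
import copy

import mindspore.dataset as ds
import mindspore.dataset.text as text
from mindspore.dataset.text import SPieceTokenizerOutType

# Placeholder paths: any trained SentencePiece model and any text corpus.
tokenizer = text.SentencePieceTokenizer("m.model", out_type=SPieceTokenizerOutType.STRING)

dataset = ds.TextFileDataset("corpus.txt", shuffle=False)
dataset = dataset.map(operations=tokenizer)

# Zip two copies of the tokenized pipeline; one branch's 'text' column is
# renamed so the zipped dataset ends up with unique column names.
branch_1 = copy.deepcopy(dataset).rename(["text"], ["text2"])
branch_2 = copy.deepcopy(dataset)
zipped = ds.zip((branch_1, branch_2))

# Before this commit, iterating here could die with a segmentation fault;
# with the fix it tokenizes normally.
for row in zipped.create_dict_iterator():
    pass
```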
Changes: 3 changed files with 59 additions and 14 deletions (+59 -14)
mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.cc  +18 -13
mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h  +1 -0
tests/ut/python/dataset/test_sentencepiece_tokenizer.py  +40 -1
mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.cc
```diff
@@ -27,17 +27,34 @@ namespace dataset {
 SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::shared_ptr<SentencePieceVocab> vocab,
                                                    const SPieceTokenizerLoadType load_type,
                                                    const SPieceTokenizerOutType out_type)
-    : vocab_(vocab), load_type_(load_type), out_type_(out_type) {}
+    : vocab_(vocab), load_type_(load_type), out_type_(out_type) {
+  auto status = processor_.LoadFromSerializedProto(vocab_.get()->model_proto());
+  if (!status.ok()) {
+    model_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "parser vocab model filed.");
+  } else {
+    model_status_ = Status::OK();
+  }
+}
 
 SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::string &model_path, const std::string &model_filename,
                                                    const SPieceTokenizerLoadType load_type,
                                                    const SPieceTokenizerOutType out_type)
     : load_type_(load_type), out_type_(out_type) {
   (void)GetModelRealPath(model_path, model_filename);
+  auto status = processor_.Load(file_path_);
+  if (!status.ok()) {
+    model_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "load vocab model filed.");
+  } else {
+    model_status_ = Status::OK();
+  }
 }
 
 Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
   IO_CHECK(input, output);
+  if (!model_status_.IsOk()) {
+    return model_status_;
+  }
   if (input->Rank() != 0 || input->type() != DataType::DE_STRING) {
     RETURN_STATUS_UNEXPECTED("the input tensor should be scalar string tensor");
   }
@@ -45,18 +62,6 @@ Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
   std::string_view sentence_v;
   RETURN_IF_NOT_OK(input->GetItemAt(&sentence_v, {}));
   std::string sentence{sentence_v};
-  if (load_type_ == SPieceTokenizerLoadType::kFile) {
-    auto status = processor_.Load(file_path_);
-    if (!status.ok()) {
-      RETURN_STATUS_UNEXPECTED("load sentence piece model failed.");
-    }
-  } else {
-    RETURN_UNEXPECTED_IF_NULL(vocab_);
-    auto status = processor_.LoadFromSerializedProto(vocab_.get()->model_proto());
-    if (!status.ok()) {
-      RETURN_STATUS_UNEXPECTED("sentence piece load model failed.");
-    }
-  }
 
   if (out_type_ == SPieceTokenizerOutType::kString) {
     std::vector<std::string> pieces;
```
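The shape of the fix is visible in the two hunks above: loading the SentencePiece model moves out of the per-row Compute() path and into both constructors, and the outcome is cached in the new model_status_ member that Compute() now merely checks. The standalone sketch below distills that pattern against the sentencepiece C++ library directly; the class and member names are illustrative, not MindSpore's:

```cpp
#include <sentencepiece_processor.h>

#include <string>
#include <vector>

// Load the model once at construction time and remember the outcome; the
// per-sentence hot path only checks the cached status instead of re-loading
// the model on every call (which is what the old Compute() did).
class CachedSentencePiece {
 public:
  explicit CachedSentencePiece(const std::string &model_file) {
    load_status_ = processor_.Load(model_file);
  }

  sentencepiece::util::Status Tokenize(const std::string &sentence, std::vector<std::string> *pieces) {
    if (!load_status_.ok()) {
      return load_status_;  // Surface the construction-time failure on each call.
    }
    return processor_.Encode(sentence, pieces);
  }

 private:
  sentencepiece::SentencePieceProcessor processor_;
  sentencepiece::util::Status load_status_;
};
```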
mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h
```diff
@@ -58,6 +58,7 @@ class SentencePieceTokenizerOp : public TensorOp {
   std::string file_path_;
   SPieceTokenizerLoadType load_type_;
   sentencepiece::SentencePieceProcessor processor_;
+  Status model_status_;
 };
 }  // namespace dataset
 }  // namespace mindspore
```
tests/ut/python/dataset/test_sentencepiece_tokenizer.py
```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+import copy
 import mindspore.dataset.text as text
 import mindspore.dataset as ds
 from mindspore.dataset.text import SentencePieceModel, to_str, SPieceTokenizerOutType
@@ -121,6 +121,44 @@ def test_build_from_dataset():
             assert value == expect[key]
 
 
+def apply_func(dataset):
+    input_columns = ['text']
+    output_columns = ['text2']
+    dataset = dataset.rename(input_columns, output_columns)
+    return dataset
+
+
+def zip_test(dataset):
+    dataset_1 = copy.deepcopy(dataset)
+    dataset_2 = copy.deepcopy(dataset)
+    dataset_1 = dataset_1.apply(apply_func)
+    dataset_zip = ds.zip((dataset_1, dataset_2))
+    expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.']
+    for i in dataset_zip.create_dict_iterator():
+        ret = to_str(i["text"])
+        for key, value in enumerate(ret):
+            assert value == expect[key]
+
+
+def concat_test(dataset):
+    dataset_1 = copy.deepcopy(dataset)
+    dataset = dataset.concat(dataset_1)
+    expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.']
+    for i in dataset.create_dict_iterator():
+        ret = to_str(i["text"])
+        for key, value in enumerate(ret):
+            assert value == expect[key]
+
+
+def test_with_zip_concat():
+    data = ds.TextFileDataset(VOCAB_FILE, shuffle=False)
+    vocab = text.SentencePieceVocab.from_dataset(data, [""], 5000, 0.9995, SentencePieceModel.UNIGRAM, {})
+    tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING)
+    dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
+    dataset = dataset.map(operations=tokenizer, num_parallel_workers=2)
+    zip_test(dataset)
+    concat_test(dataset)
+
+
 if __name__ == "__main__":
     test_from_vocab_to_str_UNIGRAM()
     test_from_vocab_to_str_BPE()
@@ -130,3 +168,4 @@ if __name__ == "__main__":
     test_from_file_to_str()
     test_from_file_to_int()
     test_build_from_dataset()
+    test_with_zip_concat()
```
登录