Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
5850b991
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5850b991
编写于
6月 18, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 18, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2211 Add TruncatePair Op to dataset
Merge pull request !2211 from h.farahat/pair_truncate
上级
fce37a5f
b9495a9c
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
296 addition
and
10 deletion
+296
-10
mindspore/ccsrc/dataset/api/python_bindings.cc
mindspore/ccsrc/dataset/api/python_bindings.cc
+7
-2
mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
+2
-1
mindspore/ccsrc/dataset/kernels/data/mask_op.cc
mindspore/ccsrc/dataset/kernels/data/mask_op.cc
+1
-1
mindspore/ccsrc/dataset/text/kernels/CMakeLists.txt
mindspore/ccsrc/dataset/text/kernels/CMakeLists.txt
+1
-0
mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.cc
...e/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.cc
+66
-0
mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h
...re/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h
+48
-0
mindspore/dataset/text/__init__.py
mindspore/dataset/text/__init__.py
+2
-2
mindspore/dataset/text/transforms.py
mindspore/dataset/text/transforms.py
+29
-1
mindspore/dataset/text/validators.py
mindspore/dataset/text/validators.py
+20
-1
mindspore/dataset/transforms/validators.py
mindspore/dataset/transforms/validators.py
+1
-1
tests/ut/cpp/dataset/mask_test.cc
tests/ut/cpp/dataset/mask_test.cc
+1
-1
tests/ut/cpp/dataset/trucate_pair_test.cc
tests/ut/cpp/dataset/trucate_pair_test.cc
+51
-0
tests/ut/python/dataset/test_pair_truncate.py
tests/ut/python/dataset/test_pair_truncate.py
+67
-0
未找到文件。
mindspore/ccsrc/dataset/api/python_bindings.cc
浏览文件 @
5850b991
...
...
@@ -40,6 +40,7 @@
#include "dataset/kernels/data/fill_op.h"
#include "dataset/kernels/data/mask_op.h"
#include "dataset/kernels/data/slice_op.h"
#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
...
...
@@ -384,7 +385,7 @@ void bindTensorOps2(py::module *m) {
*
m
,
"FillOp"
,
"Tensor operation to return tensor filled with same value as input fill value."
)
.
def
(
py
::
init
<
std
::
shared_ptr
<
Tensor
>>
());
(
void
)
py
::
class_
<
SliceOp
,
TensorOp
,
std
::
shared_ptr
<
SliceOp
>>
(
*
m
,
"SliceOp"
,
"Tensor
S
lice operation."
)
(
void
)
py
::
class_
<
SliceOp
,
TensorOp
,
std
::
shared_ptr
<
SliceOp
>>
(
*
m
,
"SliceOp"
,
"Tensor
s
lice operation."
)
.
def
(
py
::
init
<
bool
>
())
.
def
(
py
::
init
([](
const
py
::
list
&
py_list
)
{
std
::
vector
<
dsize_t
>
c_list
;
...
...
@@ -425,9 +426,13 @@ void bindTensorOps2(py::module *m) {
.
export_values
();
(
void
)
py
::
class_
<
MaskOp
,
TensorOp
,
std
::
shared_ptr
<
MaskOp
>>
(
*
m
,
"MaskOp"
,
"Tensor
operation mask
using relational comparator"
)
"Tensor
mask operation
using relational comparator"
)
.
def
(
py
::
init
<
RelationalOp
,
std
::
shared_ptr
<
Tensor
>
,
DataType
>
());
(
void
)
py
::
class_
<
TruncateSequencePairOp
,
TensorOp
,
std
::
shared_ptr
<
TruncateSequencePairOp
>>
(
*
m
,
"TruncateSequencePairOp"
,
"Tensor operation to truncate two tensors to a max_length"
)
.
def
(
py
::
init
<
int64_t
>
());
(
void
)
py
::
class_
<
RandomRotationOp
,
TensorOp
,
std
::
shared_ptr
<
RandomRotationOp
>>
(
*
m
,
"RandomRotationOp"
,
"Tensor operation to apply RandomRotation."
...
...
mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
浏览文件 @
5850b991
...
...
@@ -7,4 +7,5 @@ add_library(kernels-data OBJECT
to_float16_op.cc
fill_op.cc
slice_op.cc
mask_op.cc
)
mask_op.cc
)
mindspore/ccsrc/dataset/kernels/data/mask_op.cc
浏览文件 @
5850b991
...
...
@@ -33,7 +33,7 @@ Status MaskOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Ten
if
(
type_
!=
DataType
::
DE_BOOL
)
{
RETURN_IF_NOT_OK
(
cast_
->
Compute
(
temp_output
,
output
));
}
else
{
*
output
=
temp_output
;
*
output
=
std
::
move
(
temp_output
)
;
}
return
Status
::
OK
();
...
...
mindspore/ccsrc/dataset/text/kernels/CMakeLists.txt
浏览文件 @
5850b991
...
...
@@ -17,5 +17,6 @@ add_library(text-kernels OBJECT
unicode_char_tokenizer_op.cc
ngram_op.cc
wordpiece_tokenizer_op.cc
truncate_sequence_pair_op.cc
${
ICU_DEPEND_FILES
}
)
mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.cc
0 → 100644
浏览文件 @
5850b991
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/text/kernels/truncate_sequence_pair_op.h"
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
#include "dataset/kernels/data/slice_op.h"
namespace
mindspore
{
namespace
dataset
{
Status
TruncateSequencePairOp
::
Compute
(
const
TensorRow
&
input
,
TensorRow
*
output
)
{
IO_CHECK_VECTOR
(
input
,
output
);
CHECK_FAIL_RETURN_UNEXPECTED
(
input
.
size
()
==
2
,
"Number of inputs should be two."
);
std
::
shared_ptr
<
Tensor
>
seq1
=
input
[
0
];
std
::
shared_ptr
<
Tensor
>
seq2
=
input
[
1
];
CHECK_FAIL_RETURN_UNEXPECTED
(
seq1
->
shape
().
Rank
()
==
1
&&
seq2
->
shape
().
Rank
()
==
1
,
"Both sequences should be of rank 1"
);
dsize_t
length1
=
seq1
->
shape
()[
0
];
dsize_t
length2
=
seq2
->
shape
()[
0
];
dsize_t
outLength1
=
length1
;
dsize_t
outLength2
=
length2
;
dsize_t
total
=
length1
+
length2
;
while
(
total
>
max_length_
)
{
if
(
outLength1
>
outLength2
)
outLength1
--
;
else
outLength2
--
;
total
--
;
}
std
::
shared_ptr
<
Tensor
>
outSeq1
;
if
(
length1
!=
outLength1
)
{
std
::
unique_ptr
<
SliceOp
>
slice1
(
new
SliceOp
(
Slice
(
outLength1
-
length1
)));
RETURN_IF_NOT_OK
(
slice1
->
Compute
(
seq1
,
&
outSeq1
));
}
else
{
outSeq1
=
std
::
move
(
seq1
);
}
std
::
shared_ptr
<
Tensor
>
outSeq2
;
if
(
length2
!=
outLength2
)
{
std
::
unique_ptr
<
SliceOp
>
slice2
(
new
SliceOp
(
Slice
(
outLength2
-
length2
)));
RETURN_IF_NOT_OK
(
slice2
->
Compute
(
seq2
,
&
outSeq2
));
}
else
{
outSeq2
=
std
::
move
(
seq2
);
}
output
->
push_back
(
outSeq1
);
output
->
push_back
(
outSeq2
);
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h
0 → 100644
浏览文件 @
5850b991
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_
#define DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/kernels/data/data_utils.h"
namespace
mindspore
{
namespace
dataset
{
class
TruncateSequencePairOp
:
public
TensorOp
{
public:
explicit
TruncateSequencePairOp
(
dsize_t
length
)
:
max_length_
(
length
)
{}
~
TruncateSequencePairOp
()
override
=
default
;
void
Print
(
std
::
ostream
&
out
)
const
override
{
out
<<
"TruncateSequencePairOp"
;
}
Status
Compute
(
const
TensorRow
&
input
,
TensorRow
*
output
)
override
;
private:
dsize_t
max_length_
;
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_
mindspore/dataset/text/__init__.py
浏览文件 @
5850b991
...
...
@@ -16,12 +16,12 @@
mindspore.dataset.text
"""
import
platform
from
.transforms
import
Lookup
,
JiebaTokenizer
,
UnicodeCharTokenizer
,
Ngram
,
WordpieceTokenizer
from
.transforms
import
Lookup
,
JiebaTokenizer
,
UnicodeCharTokenizer
,
Ngram
,
WordpieceTokenizer
,
TruncateSequencePair
from
.utils
import
to_str
,
to_bytes
,
JiebaMode
,
Vocab
,
NormalizeForm
__all__
=
[
"Lookup"
,
"JiebaTokenizer"
,
"UnicodeCharTokenizer"
,
"Ngram"
,
"to_str"
,
"to_bytes"
,
"JiebaMode"
,
"Vocab"
,
"WordpieceTokenizer"
"to_str"
,
"to_bytes"
,
"JiebaMode"
,
"Vocab"
,
"WordpieceTokenizer"
,
"TruncateSequencePair"
]
if
platform
.
system
().
lower
()
!=
'windows'
:
...
...
mindspore/dataset/text/transforms.py
浏览文件 @
5850b991
...
...
@@ -23,7 +23,7 @@ import mindspore._c_dataengine as cde
from
.utils
import
JiebaMode
,
NormalizeForm
from
.validators
import
check_lookup
,
check_jieba_add_dict
,
\
check_jieba_add_word
,
check_jieba_init
,
check_ngram
check_jieba_add_word
,
check_jieba_init
,
check_ngram
,
check_pair_truncate
class
Lookup
(
cde
.
LookupOp
):
...
...
@@ -344,3 +344,31 @@ if platform.system().lower() != 'windows':
self
.
preserve_unused_token
=
preserve_unused_token
super
().
__init__
(
self
.
vocab
,
self
.
suffix_indicator
,
self
.
max_bytes_per_token
,
self
.
unknown_token
,
self
.
lower_case
,
self
.
keep_whitespace
,
self
.
normalization_form
,
self
.
preserve_unused_token
)
class
TruncateSequencePair
(
cde
.
TruncateSequencePairOp
):
"""
Truncate a pair of rank-1 tensors such that the total length is less than max_length.
This operation takes two input tensors and returns two output Tenors.
Args:
max_length(int): Maximum length required.
Examples:
>>> # Data before
>>> # | col1 | col2 |
>>> # +---------+---------|
>>> # | [1,2,3] | [4,5] |
>>> # +---------+---------+
>>> data = data.map(operations=TruncateSequencePair(4))
>>> # Data after
>>> # | col1 | col2 |
>>> # +---------+---------+
>>> # | [1,2] | [4,5] |
>>> # +---------+---------+
"""
@
check_pair_truncate
def
__init__
(
self
,
max_length
):
super
().
__init__
(
max_length
)
mindspore/dataset/text/validators.py
浏览文件 @
5850b991
...
...
@@ -20,7 +20,7 @@ from functools import wraps
import
mindspore._c_dataengine
as
cde
from
..transforms.validators
import
check_uint32
from
..transforms.validators
import
check_uint32
,
check_pos_int64
def
check_lookup
(
method
):
...
...
@@ -298,3 +298,22 @@ def check_ngram(method):
return
method
(
self
,
**
kwargs
)
return
new_method
def
check_pair_truncate
(
method
):
"""Wrapper method to check the parameters of number of pair truncate."""
@
wraps
(
method
)
def
new_method
(
self
,
*
args
,
**
kwargs
):
max_length
=
(
list
(
args
)
+
[
None
])[
0
]
if
"max_length"
in
kwargs
:
max_length
=
kwargs
.
get
(
"max_length"
)
if
max_length
is
None
:
raise
ValueError
(
"max_length is not provided."
)
check_pos_int64
(
max_length
)
kwargs
[
"max_length"
]
=
max_length
return
method
(
self
,
**
kwargs
)
return
new_method
mindspore/dataset/transforms/validators.py
浏览文件 @
5850b991
...
...
@@ -216,7 +216,7 @@ def check_slice_op(method):
def
check_mask_op
(
method
):
"""Wrapper method to check the parameters of
slice
."""
"""Wrapper method to check the parameters of
mask
."""
@
wraps
(
method
)
def
new_method
(
self
,
*
args
,
**
kwargs
):
...
...
tests/ut/cpp/dataset/mask_test.cc
浏览文件 @
5850b991
/**
* Copyright 20
19
Huawei Technologies Co., Ltd
* Copyright 20
20
Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
...
...
tests/ut/cpp/dataset/trucate_pair_test.cc
0 → 100644
浏览文件 @
5850b991
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <string>
#include "dataset/core/client.h"
#include "common/common.h"
#include "gtest/gtest.h"
#include "securec.h"
#include "dataset/core/tensor.h"
#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h"
using
namespace
mindspore
::
dataset
;
namespace
py
=
pybind11
;
class
MindDataTestTruncatePairOp
:
public
UT
::
Common
{
public:
MindDataTestTruncatePairOp
()
{}
void
SetUp
()
{
GlobalInit
();
}
};
TEST_F
(
MindDataTestTruncatePairOp
,
Basics
)
{
std
::
shared_ptr
<
Tensor
>
t1
;
Tensor
::
CreateTensor
(
&
t1
,
std
::
vector
<
uint32_t
>
({
1
,
2
,
3
}));
std
::
shared_ptr
<
Tensor
>
t2
;
Tensor
::
CreateTensor
(
&
t2
,
std
::
vector
<
uint32_t
>
({
4
,
5
}));
TensorRow
in
({
t1
,
t2
});
std
::
shared_ptr
<
TruncateSequencePairOp
>
op
=
std
::
make_shared
<
TruncateSequencePairOp
>
(
4
);
TensorRow
out
;
ASSERT_TRUE
(
op
->
Compute
(
in
,
&
out
).
IsOk
());
std
::
shared_ptr
<
Tensor
>
out1
;
Tensor
::
CreateTensor
(
&
out1
,
std
::
vector
<
uint32_t
>
({
1
,
2
}));
std
::
shared_ptr
<
Tensor
>
out2
;
Tensor
::
CreateTensor
(
&
out2
,
std
::
vector
<
uint32_t
>
({
4
,
5
}));
ASSERT_EQ
(
*
out1
,
*
out
[
0
]);
ASSERT_EQ
(
*
out2
,
*
out
[
1
]);
}
tests/ut/python/dataset/test_pair_truncate.py
0 → 100644
浏览文件 @
5850b991
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing Mask op in DE
"""
import
numpy
as
np
import
pytest
import
mindspore.dataset
as
ds
import
mindspore.dataset.text
as
text
def
compare
(
in1
,
in2
,
length
,
out1
,
out2
):
data
=
ds
.
NumpySlicesDataset
({
"s1"
:
[
in1
],
"s2"
:
[
in2
]})
data
=
data
.
map
(
input_columns
=
[
"s1"
,
"s2"
],
operations
=
text
.
TruncateSequencePair
(
length
))
for
d
in
data
.
create_dict_iterator
():
np
.
testing
.
assert_array_equal
(
out1
,
d
[
"s1"
])
np
.
testing
.
assert_array_equal
(
out2
,
d
[
"s2"
])
def
test_basics
():
compare
(
in1
=
[
1
,
2
,
3
],
in2
=
[
4
,
5
],
length
=
4
,
out1
=
[
1
,
2
],
out2
=
[
4
,
5
])
compare
(
in1
=
[
1
,
2
],
in2
=
[
4
,
5
],
length
=
4
,
out1
=
[
1
,
2
],
out2
=
[
4
,
5
])
compare
(
in1
=
[
1
],
in2
=
[
4
],
length
=
4
,
out1
=
[
1
],
out2
=
[
4
])
compare
(
in1
=
[
1
,
2
,
3
,
4
],
in2
=
[
5
],
length
=
4
,
out1
=
[
1
,
2
,
3
],
out2
=
[
5
])
compare
(
in1
=
[
1
,
2
,
3
,
4
],
in2
=
[
5
,
6
,
7
,
8
],
length
=
4
,
out1
=
[
1
,
2
],
out2
=
[
5
,
6
])
def
test_basics_odd
():
compare
(
in1
=
[
1
,
2
,
3
],
in2
=
[
4
,
5
],
length
=
3
,
out1
=
[
1
,
2
],
out2
=
[
4
])
compare
(
in1
=
[
1
,
2
],
in2
=
[
4
,
5
],
length
=
3
,
out1
=
[
1
,
2
],
out2
=
[
4
])
compare
(
in1
=
[
1
],
in2
=
[
4
],
length
=
5
,
out1
=
[
1
],
out2
=
[
4
])
compare
(
in1
=
[
1
,
2
,
3
,
4
],
in2
=
[
5
],
length
=
3
,
out1
=
[
1
,
2
],
out2
=
[
5
])
compare
(
in1
=
[
1
,
2
,
3
,
4
],
in2
=
[
5
,
6
,
7
,
8
],
length
=
3
,
out1
=
[
1
,
2
],
out2
=
[
5
])
def
test_basics_str
():
compare
(
in1
=
[
b
"1"
,
b
"2"
,
b
"3"
],
in2
=
[
4
,
5
],
length
=
4
,
out1
=
[
b
"1"
,
b
"2"
],
out2
=
[
4
,
5
])
compare
(
in1
=
[
b
"1"
,
b
"2"
],
in2
=
[
b
"4"
,
b
"5"
],
length
=
4
,
out1
=
[
b
"1"
,
b
"2"
],
out2
=
[
b
"4"
,
b
"5"
])
compare
(
in1
=
[
b
"1"
],
in2
=
[
4
],
length
=
4
,
out1
=
[
b
"1"
],
out2
=
[
4
])
compare
(
in1
=
[
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
],
in2
=
[
b
"5"
],
length
=
4
,
out1
=
[
b
"1"
,
b
"2"
,
b
"3"
],
out2
=
[
b
"5"
])
compare
(
in1
=
[
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
],
in2
=
[
5
,
6
,
7
,
8
],
length
=
4
,
out1
=
[
b
"1"
,
b
"2"
],
out2
=
[
5
,
6
])
def
test_exceptions
():
with
pytest
.
raises
(
RuntimeError
)
as
info
:
compare
(
in1
=
[
1
,
2
,
3
,
4
],
in2
=
[
5
,
6
,
7
,
8
],
length
=
1
,
out1
=
[
1
,
2
],
out2
=
[
5
])
assert
"Indices are empty, generated tensor would be empty"
in
str
(
info
.
value
)
if
__name__
==
"__main__"
:
test_basics
()
test_basics_odd
()
test_basics_str
()
test_exceptions
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录