diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index c0add5a896afc073f06b5391b63867ad8c023c5e..c105a5410444f951cd39e85b77e3c5db72612ca2 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -40,6 +40,7 @@ #include "dataset/kernels/data/fill_op.h" #include "dataset/kernels/data/mask_op.h" #include "dataset/kernels/data/slice_op.h" +#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h" #include "dataset/kernels/data/type_cast_op.h" #include "dataset/engine/datasetops/source/cifar_op.h" #include "dataset/engine/datasetops/source/image_folder_op.h" @@ -384,7 +385,7 @@ void bindTensorOps2(py::module *m) { *m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.") .def(py::init>()); - (void)py::class_>(*m, "SliceOp", "Tensor Slice operation.") + (void)py::class_>(*m, "SliceOp", "Tensor slice operation.") .def(py::init()) .def(py::init([](const py::list &py_list) { std::vector c_list; @@ -425,9 +426,13 @@ void bindTensorOps2(py::module *m) { .export_values(); (void)py::class_>(*m, "MaskOp", - "Tensor operation mask using relational comparator") + "Tensor mask operation using relational comparator") .def(py::init, DataType>()); + (void)py::class_>( + *m, "TruncateSequencePairOp", "Tensor operation to truncate two tensors to a max_length") + .def(py::init()); + (void)py::class_>( *m, "RandomRotationOp", "Tensor operation to apply RandomRotation." diff --git a/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt b/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt index 03457ca4f5077bde9f01f0734b4b463e808ada58..80620dd91a4f8a3d02763b318d5cf9a2808a76f9 100644 --- a/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt @@ -7,4 +7,5 @@ add_library(kernels-data OBJECT to_float16_op.cc fill_op.cc slice_op.cc - mask_op.cc) + mask_op.cc + ) diff --git a/mindspore/ccsrc/dataset/kernels/data/mask_op.cc b/mindspore/ccsrc/dataset/kernels/data/mask_op.cc index ba98ab58921534d7fa8d02cbc368ebc97df3bbdb..2cfeb7e36fbfaed70dea2d202a3698c5edb0cd64 100644 --- a/mindspore/ccsrc/dataset/kernels/data/mask_op.cc +++ b/mindspore/ccsrc/dataset/kernels/data/mask_op.cc @@ -33,7 +33,7 @@ Status MaskOp::Compute(const std::shared_ptr &input, std::shared_ptrCompute(temp_output, output)); } else { - *output = temp_output; + *output = std::move(temp_output); } return Status::OK(); diff --git a/mindspore/ccsrc/dataset/text/kernels/CMakeLists.txt b/mindspore/ccsrc/dataset/text/kernels/CMakeLists.txt index 8c4d19ab2c4b3cd37090af5ede8d05546c77351a..396d03fe44c94e1812a46914d23a4bbb250447ea 100644 --- a/mindspore/ccsrc/dataset/text/kernels/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/text/kernels/CMakeLists.txt @@ -17,5 +17,6 @@ add_library(text-kernels OBJECT unicode_char_tokenizer_op.cc ngram_op.cc wordpiece_tokenizer_op.cc + truncate_sequence_pair_op.cc ${ICU_DEPEND_FILES} ) diff --git a/mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.cc b/mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..136d5006dfe1ae657d3f75d15e2c98ef0e94f41b --- /dev/null +++ b/mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "dataset/text/kernels/truncate_sequence_pair_op.h" + +#include "dataset/core/tensor.h" +#include "dataset/kernels/tensor_op.h" +#include "dataset/kernels/data/slice_op.h" + +namespace mindspore { +namespace dataset { + +Status TruncateSequencePairOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 2, "Number of inputs should be two."); + std::shared_ptr seq1 = input[0]; + std::shared_ptr seq2 = input[1]; + CHECK_FAIL_RETURN_UNEXPECTED(seq1->shape().Rank() == 1 && seq2->shape().Rank() == 1, + "Both sequences should be of rank 1"); + dsize_t length1 = seq1->shape()[0]; + dsize_t length2 = seq2->shape()[0]; + dsize_t outLength1 = length1; + dsize_t outLength2 = length2; + + dsize_t total = length1 + length2; + while (total > max_length_) { + if (outLength1 > outLength2) + outLength1--; + else + outLength2--; + total--; + } + std::shared_ptr outSeq1; + if (length1 != outLength1) { + std::unique_ptr slice1(new SliceOp(Slice(outLength1 - length1))); + RETURN_IF_NOT_OK(slice1->Compute(seq1, &outSeq1)); + } else { + outSeq1 = std::move(seq1); + } + + std::shared_ptr outSeq2; + if (length2 != outLength2) { + std::unique_ptr slice2(new SliceOp(Slice(outLength2 - length2))); + RETURN_IF_NOT_OK(slice2->Compute(seq2, &outSeq2)); + } else { + outSeq2 = std::move(seq2); + } + output->push_back(outSeq1); + output->push_back(outSeq2); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h b/mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h new file mode 100644 index 0000000000000000000000000000000000000000..e8be6802a8bffffe99c89d5e623596bf2fca728f --- /dev/null +++ b/mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ +#define DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ + +#include +#include +#include +#include +#include + +#include "dataset/core/tensor.h" +#include "dataset/kernels/tensor_op.h" +#include "dataset/kernels/data/type_cast_op.h" +#include "dataset/kernels/data/data_utils.h" + +namespace mindspore { +namespace dataset { + +class TruncateSequencePairOp : public TensorOp { + public: + explicit TruncateSequencePairOp(dsize_t length) : max_length_(length) {} + + ~TruncateSequencePairOp() override = default; + + void Print(std::ostream &out) const override { out << "TruncateSequencePairOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + private: + dsize_t max_length_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ diff --git a/mindspore/dataset/text/__init__.py b/mindspore/dataset/text/__init__.py index b98093d45ae430e9ab26630dee1e425025cabbd0..364ea75d574e7ca7c10a13c3e49a9fc9b7cc9ebb 100644 --- a/mindspore/dataset/text/__init__.py +++ b/mindspore/dataset/text/__init__.py @@ -16,12 +16,12 @@ mindspore.dataset.text """ import platform -from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer +from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm __all__ = [ "Lookup", "JiebaTokenizer", "UnicodeCharTokenizer", "Ngram", - "to_str", "to_bytes", "JiebaMode", "Vocab", "WordpieceTokenizer" + "to_str", "to_bytes", "JiebaMode", "Vocab", "WordpieceTokenizer", "TruncateSequencePair" ] if platform.system().lower() != 'windows': diff --git a/mindspore/dataset/text/transforms.py b/mindspore/dataset/text/transforms.py index c9cfd55999db6893e2eb2b7d4eaec22a75b88d7e..4a64ed3c42495ba45245c3c6eeb0efcff1f3eb12 100644 --- a/mindspore/dataset/text/transforms.py +++ b/mindspore/dataset/text/transforms.py @@ -23,7 +23,7 @@ import mindspore._c_dataengine as cde from .utils import JiebaMode, NormalizeForm from .validators import check_lookup, check_jieba_add_dict, \ - check_jieba_add_word, check_jieba_init, check_ngram + check_jieba_add_word, check_jieba_init, check_ngram, check_pair_truncate class Lookup(cde.LookupOp): @@ -344,3 +344,31 @@ if platform.system().lower() != 'windows': self.preserve_unused_token = preserve_unused_token super().__init__(self.vocab, self.suffix_indicator, self.max_bytes_per_token, self.unknown_token, self.lower_case, self.keep_whitespace, self.normalization_form, self.preserve_unused_token) + + +class TruncateSequencePair(cde.TruncateSequencePairOp): + """ + Truncate a pair of rank-1 tensors such that the total length is less than max_length. + + This operation takes two input tensors and returns two output Tenors. + + Args: + max_length(int): Maximum length required. + + Examples: + >>> # Data before + >>> # | col1 | col2 | + >>> # +---------+---------| + >>> # | [1,2,3] | [4,5] | + >>> # +---------+---------+ + >>> data = data.map(operations=TruncateSequencePair(4)) + >>> # Data after + >>> # | col1 | col2 | + >>> # +---------+---------+ + >>> # | [1,2] | [4,5] | + >>> # +---------+---------+ + """ + + @check_pair_truncate + def __init__(self, max_length): + super().__init__(max_length) diff --git a/mindspore/dataset/text/validators.py b/mindspore/dataset/text/validators.py index 25d0aaf2415f313676d0e3901af984535cf4806e..4004bf40a43e0bc0777e20c8792c02a8aee05803 100644 --- a/mindspore/dataset/text/validators.py +++ b/mindspore/dataset/text/validators.py @@ -20,7 +20,7 @@ from functools import wraps import mindspore._c_dataengine as cde -from ..transforms.validators import check_uint32 +from ..transforms.validators import check_uint32, check_pos_int64 def check_lookup(method): @@ -298,3 +298,22 @@ def check_ngram(method): return method(self, **kwargs) return new_method + + +def check_pair_truncate(method): + """Wrapper method to check the parameters of number of pair truncate.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + max_length = (list(args) + [None])[0] + if "max_length" in kwargs: + max_length = kwargs.get("max_length") + if max_length is None: + raise ValueError("max_length is not provided.") + + check_pos_int64(max_length) + kwargs["max_length"] = max_length + + return method(self, **kwargs) + + return new_method diff --git a/mindspore/dataset/transforms/validators.py b/mindspore/dataset/transforms/validators.py index 4033b573ca6ad5b07187ba55ebae2c76a3d9d9d3..cb9fd158a4b0270609af242c849480469192277e 100644 --- a/mindspore/dataset/transforms/validators.py +++ b/mindspore/dataset/transforms/validators.py @@ -216,7 +216,7 @@ def check_slice_op(method): def check_mask_op(method): - """Wrapper method to check the parameters of slice.""" + """Wrapper method to check the parameters of mask.""" @wraps(method) def new_method(self, *args, **kwargs): diff --git a/tests/ut/cpp/dataset/mask_test.cc b/tests/ut/cpp/dataset/mask_test.cc index b2220f2a3f339934ff121f99092a3badd7b0c137..eb8a49aa36e87fbb651ad8fe4fc781f27390fca9 100644 --- a/tests/ut/cpp/dataset/mask_test.cc +++ b/tests/ut/cpp/dataset/mask_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/tests/ut/cpp/dataset/trucate_pair_test.cc b/tests/ut/cpp/dataset/trucate_pair_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..95e2aaa11b39b342e7b1c3d4c5f2a6c2fd79ad1a --- /dev/null +++ b/tests/ut/cpp/dataset/trucate_pair_test.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "dataset/core/client.h" +#include "common/common.h" +#include "gtest/gtest.h" +#include "securec.h" +#include "dataset/core/tensor.h" +#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h" + +using namespace mindspore::dataset; + +namespace py = pybind11; + +class MindDataTestTruncatePairOp : public UT::Common { + public: + MindDataTestTruncatePairOp() {} + + void SetUp() { GlobalInit(); } +}; + +TEST_F(MindDataTestTruncatePairOp, Basics) { + std::shared_ptr t1; + Tensor::CreateTensor(&t1, std::vector({1, 2, 3})); + std::shared_ptr t2; + Tensor::CreateTensor(&t2, std::vector({4, 5})); + TensorRow in({t1, t2}); + std::shared_ptr op = std::make_shared(4); + TensorRow out; + ASSERT_TRUE(op->Compute(in, &out).IsOk()); + std::shared_ptr out1; + Tensor::CreateTensor(&out1, std::vector({1, 2})); + std::shared_ptr out2; + Tensor::CreateTensor(&out2, std::vector({4, 5})); + ASSERT_EQ(*out1, *out[0]); + ASSERT_EQ(*out2, *out[1]); +} diff --git a/tests/ut/python/dataset/test_pair_truncate.py b/tests/ut/python/dataset/test_pair_truncate.py new file mode 100644 index 0000000000000000000000000000000000000000..6b1138e5a9ceaa7fe9a823a40acb07dc1d054ee0 --- /dev/null +++ b/tests/ut/python/dataset/test_pair_truncate.py @@ -0,0 +1,67 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing Mask op in DE +""" +import numpy as np +import pytest + +import mindspore.dataset as ds +import mindspore.dataset.text as text + + +def compare(in1, in2, length, out1, out2): + data = ds.NumpySlicesDataset({"s1": [in1], "s2": [in2]}) + data = data.map(input_columns=["s1", "s2"], operations=text.TruncateSequencePair(length)) + for d in data.create_dict_iterator(): + np.testing.assert_array_equal(out1, d["s1"]) + np.testing.assert_array_equal(out2, d["s2"]) + + +def test_basics(): + compare(in1=[1, 2, 3], in2=[4, 5], length=4, out1=[1, 2], out2=[4, 5]) + compare(in1=[1, 2], in2=[4, 5], length=4, out1=[1, 2], out2=[4, 5]) + compare(in1=[1], in2=[4], length=4, out1=[1], out2=[4]) + compare(in1=[1, 2, 3, 4], in2=[5], length=4, out1=[1, 2, 3], out2=[5]) + compare(in1=[1, 2, 3, 4], in2=[5, 6, 7, 8], length=4, out1=[1, 2], out2=[5, 6]) + + +def test_basics_odd(): + compare(in1=[1, 2, 3], in2=[4, 5], length=3, out1=[1, 2], out2=[4]) + compare(in1=[1, 2], in2=[4, 5], length=3, out1=[1, 2], out2=[4]) + compare(in1=[1], in2=[4], length=5, out1=[1], out2=[4]) + compare(in1=[1, 2, 3, 4], in2=[5], length=3, out1=[1, 2], out2=[5]) + compare(in1=[1, 2, 3, 4], in2=[5, 6, 7, 8], length=3, out1=[1, 2], out2=[5]) + + +def test_basics_str(): + compare(in1=[b"1", b"2", b"3"], in2=[4, 5], length=4, out1=[b"1", b"2"], out2=[4, 5]) + compare(in1=[b"1", b"2"], in2=[b"4", b"5"], length=4, out1=[b"1", b"2"], out2=[b"4", b"5"]) + compare(in1=[b"1"], in2=[4], length=4, out1=[b"1"], out2=[4]) + compare(in1=[b"1", b"2", b"3", b"4"], in2=[b"5"], length=4, out1=[b"1", b"2", b"3"], out2=[b"5"]) + compare(in1=[b"1", b"2", b"3", b"4"], in2=[5, 6, 7, 8], length=4, out1=[b"1", b"2"], out2=[5, 6]) + + +def test_exceptions(): + with pytest.raises(RuntimeError) as info: + compare(in1=[1, 2, 3, 4], in2=[5, 6, 7, 8], length=1, out1=[1, 2], out2=[5]) + assert "Indices are empty, generated tensor would be empty" in str(info.value) + + +if __name__ == "__main__": + test_basics() + test_basics_odd() + test_basics_str() + test_exceptions()