提交 4136892a 编写于 作者: Y YangLuo

add SlidingWindow Op

上级 d89cedb9
...@@ -77,6 +77,7 @@ ...@@ -77,6 +77,7 @@
#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h" #include "minddata/dataset/text/kernels/jieba_tokenizer_op.h"
#include "minddata/dataset/text/kernels/lookup_op.h" #include "minddata/dataset/text/kernels/lookup_op.h"
#include "minddata/dataset/text/kernels/ngram_op.h" #include "minddata/dataset/text/kernels/ngram_op.h"
#include "minddata/dataset/text/kernels/sliding_window_op.h"
#include "minddata/dataset/text/kernels/to_number_op.h" #include "minddata/dataset/text/kernels/to_number_op.h"
#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h" #include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h"
#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h" #include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h"
...@@ -640,6 +641,9 @@ void bindTokenizerOps(py::module *m) { ...@@ -640,6 +641,9 @@ void bindTokenizerOps(py::module *m) {
py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken, py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken,
py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken), py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken),
py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets); py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets);
(void)py::class_<SlidingWindowOp, TensorOp, std::shared_ptr<SlidingWindowOp>>(
*m, "SlidingWindowOp", "TensorOp to apply sliding window to a 1-D Tensor.")
.def(py::init<uint32_t, int32_t>(), py::arg("width"), py::arg("axis"));
} }
void bindDependIcuTokenizerOps(py::module *m) { void bindDependIcuTokenizerOps(py::module *m) {
......
...@@ -120,6 +120,7 @@ constexpr char kCaseFoldOp[] = "CaseFoldOp"; ...@@ -120,6 +120,7 @@ constexpr char kCaseFoldOp[] = "CaseFoldOp";
constexpr char kJiebaTokenizerOp[] = "JiebaTokenizerOp"; constexpr char kJiebaTokenizerOp[] = "JiebaTokenizerOp";
constexpr char kLookupOp[] = "LookupOp"; constexpr char kLookupOp[] = "LookupOp";
constexpr char kNgramOp[] = "NgramOp"; constexpr char kNgramOp[] = "NgramOp";
constexpr char kSlidingWindowOp[] = "SlidingWindowOp";
constexpr char kNormalizeUTF8Op[] = "NormalizeUTF8Op"; constexpr char kNormalizeUTF8Op[] = "NormalizeUTF8Op";
constexpr char kRegexReplaceOp[] = "RegexReplaceOp"; constexpr char kRegexReplaceOp[] = "RegexReplaceOp";
constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp"; constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp";
......
...@@ -12,10 +12,12 @@ if (NOT (CMAKE_SYSTEM_NAME MATCHES "Windows")) ...@@ -12,10 +12,12 @@ if (NOT (CMAKE_SYSTEM_NAME MATCHES "Windows"))
whitespace_tokenizer_op.cc) whitespace_tokenizer_op.cc)
endif() endif()
add_library(text-kernels OBJECT add_library(text-kernels OBJECT
data_utils.cc
lookup_op.cc lookup_op.cc
jieba_tokenizer_op.cc jieba_tokenizer_op.cc
unicode_char_tokenizer_op.cc unicode_char_tokenizer_op.cc
ngram_op.cc ngram_op.cc
sliding_window_op.cc
wordpiece_tokenizer_op.cc wordpiece_tokenizer_op.cc
truncate_sequence_pair_op.cc truncate_sequence_pair_op.cc
to_number_op.cc to_number_op.cc
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/text/kernels/data_utils.h"
#include <algorithm>
#include <limits>
#include <string>
#include <vector>
#include "minddata/dataset/core/pybind_support.h"
#include "minddata/dataset/kernels/data/type_cast_op.h"
#include "minddata/dataset/kernels/data/slice_op.h"
#include "minddata/dataset/kernels/data/concatenate_op.h"
namespace mindspore {
namespace dataset {
Status SlidingWindowHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, TensorShape out_shape,
uint32_t width, int32_t axis) {
// if the data row has fewer items than width, the corresponding result row will be empty
if (out_shape.Size() == 0) {
MS_LOG(WARNING) << "The data row has fewer items than width, the result will be empty.";
if (input->type().value() == DataType::DE_STRING) {
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, std::vector<std::string>{}, TensorShape({0})));
} else {
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, TensorShape({0}), input->type()));
}
return Status::OK();
}
axis = Tensor::HandleNeg(axis, input->shape().Size());
int32_t axis_end = input->shape()[axis];
std::shared_ptr<Tensor> tmp;
auto concatenate_op = std::make_unique<ConcatenateOp>(axis, nullptr, nullptr);
// Slice on specified axis and concatenate on new axis
for (int32_t i = 0; i + width <= axis_end; i++) {
auto slice_op = std::make_unique<SliceOp>(Slice(i, i + width, 1));
slice_op->Compute(input, &tmp);
if (i == 0) {
*output = tmp;
} else {
TensorRow in({*output, tmp});
TensorRow out_row;
concatenate_op->Compute(in, &out_row);
*output = out_row[0];
}
}
(*output)->Reshape(out_shape);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_TEXT_DATA_UTILS_H_
#define DATASET_KERNELS_TEXT_DATA_UTILS_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/core/tensor_row.h"
namespace mindspore {
namespace dataset {
/// \brief Helper method that perform sliding window on input tensor.
/// \param[in] input - Input tensor.
/// \param[in] out_shape - Output shape of output tensor.
/// \param[in] width - The axis along which sliding window is computed.
/// \param[in] axis - The width of the window.
/// \param[out] output - Output tensor
/// \return Status return code
Status SlidingWindowHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, TensorShape out_shape,
uint32_t width, int32_t axis);
} // namespace dataset
} // namespace mindspore
#endif // DATASET_KERNELS_TEXT_DATA_UTILS_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/text/kernels/sliding_window_op.h"
namespace mindspore {
namespace dataset {
Status SlidingWindowOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1, "SlidingWindosOp supports 1D Tensors only for now.");
CHECK_FAIL_RETURN_UNEXPECTED(axis_ == 0 || axis_ == -1, "axis supports 0 or -1 only for now.");
std::vector<TensorShape> input_shape = {input->shape()};
std::vector<TensorShape> output_shape = {TensorShape({})};
RETURN_IF_NOT_OK(OutputShape(input_shape, output_shape));
RETURN_IF_NOT_OK(SlidingWindowHelper(input, output, output_shape[0], width_, axis_));
return Status::OK();
}
Status SlidingWindowOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput(), "incorrect num of inputs\n");
int32_t axis = Tensor::HandleNeg(axis_, inputs[0].Size());
TensorShape input_shape = inputs[0];
std::vector<dsize_t> output_shape_initializer;
// if a data row has fewer items than width, the corresponding result row will be empty.
if (input_shape[axis] >= width_) {
for (int32_t idx = 0; idx < input_shape.Size(); ++idx) {
if (idx != axis) {
output_shape_initializer.push_back(input_shape[idx]);
} else {
output_shape_initializer.push_back(input_shape[idx] - (width_ - 1));
output_shape_initializer.push_back(width_);
}
}
}
outputs.pop_back();
outputs.emplace_back(TensorShape(output_shape_initializer));
CHECK_FAIL_RETURN_UNEXPECTED(outputs.size() == NumOutput(), "incorrect num of outputs\n");
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_
#define DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/text/kernels/data_utils.h"
namespace mindspore {
namespace dataset {
class SlidingWindowOp : public TensorOp {
public:
/// \brief Constructor of SlidingWindowOp.
/// \param[in] width - The axis along which sliding window is computed.
/// \param[in] axis - The width of the window.
/// \return Status return code
explicit SlidingWindowOp(uint32_t width, int32_t axis = 0) : width_(width), axis_(axis) {}
/// \brief Destructor of SlidingWindowOp.
~SlidingWindowOp() override = default;
/// \brief Perform sliding window to tensor.
/// \param[in] input - Input tensor of Op.
/// \param[out] output - output tensor of Op.
/// \return Status return code
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
/// \brief Calculate tensor shape for output tensor.
/// \param[in] inputs - Input tensor shapes.
/// \param[out] outputs - Output tensor shapes.
/// \return Status return code
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
/// \brief Print args for debugging.
/// \param[in] out - std::ostream &out.
void Print(std::ostream &out) const override { out << "SliceWindowOp"; }
/// \brief Print name of op.
std::string Name() const override { return kSlidingWindowOp; }
private:
uint32_t width_; // The width of the window. Must be an integer and greater than zero.
int32_t axis_; // The axis along which sliding window is computed, only support 0/-1 for now.
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_KERNELS_TEXT_SLIDING_WINDOW_OP_H_
...@@ -19,13 +19,13 @@ utils provides some general methods for nlp text processing. ...@@ -19,13 +19,13 @@ utils provides some general methods for nlp text processing.
""" """
import platform import platform
from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair, \ from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram, WordpieceTokenizer, TruncateSequencePair, \
ToNumber ToNumber, SlidingWindow
from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm from .utils import to_str, to_bytes, JiebaMode, Vocab, NormalizeForm
__all__ = [ __all__ = [
"Lookup", "JiebaTokenizer", "UnicodeCharTokenizer", "Ngram", "Lookup", "JiebaTokenizer", "UnicodeCharTokenizer", "Ngram",
"to_str", "to_bytes", "Vocab", "WordpieceTokenizer", "TruncateSequencePair", "ToNumber", "to_str", "to_bytes", "Vocab", "WordpieceTokenizer", "TruncateSequencePair", "ToNumber",
"PythonTokenizer" "PythonTokenizer", "SlidingWindow"
] ]
if platform.system().lower() != 'windows': if platform.system().lower() != 'windows':
......
...@@ -54,7 +54,7 @@ from .utils import JiebaMode, NormalizeForm, to_str ...@@ -54,7 +54,7 @@ from .utils import JiebaMode, NormalizeForm, to_str
from .validators import check_lookup, check_jieba_add_dict, \ from .validators import check_lookup, check_jieba_add_dict, \
check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer,\ check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer,\
check_wordpiece_tokenizer, check_regex_tokenizer, check_basic_tokenizer, check_ngram, check_pair_truncate,\ check_wordpiece_tokenizer, check_regex_tokenizer, check_basic_tokenizer, check_ngram, check_pair_truncate,\
check_to_number, check_bert_tokenizer, check_python_tokenizer check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow
from ..core.datatypes import mstype_to_detype from ..core.datatypes import mstype_to_detype
...@@ -72,6 +72,34 @@ class Lookup(cde.LookupOp): ...@@ -72,6 +72,34 @@ class Lookup(cde.LookupOp):
def __init__(self, vocab, unknown_token=None): def __init__(self, vocab, unknown_token=None):
super().__init__(vocab, unknown_token) super().__init__(vocab, unknown_token)
class SlidingWindow(cde.SlidingWindowOp):
"""
TensorOp to construct a tensor from data (only 1-D for now), where each element in the dimension axis
is a slice of data starting at the corresponding position, with a specified width.
Args:
width (int): The width of the window. Must be an integer and greater than zero.
axis (int, optional): The axis along which sliding window is computed (default=0).
Examples:
>>> # Data before
>>> # | col1 |
>>> # +-------------+
>>> # | [1,2,3,4,5] |
>>> # +-------------+
>>> data = data.map(operations=SlidingWindow(3, 0))
>>> # Data after
>>> # | col1 |
>>> # +-------------+
>>> # | [[1,2,3], |
>>> # | [2,3,4], |
>>> # | [3,4,5]] |
>>> # +--------------+
"""
@check_slidingwindow
def __init__(self, width, axis=0):
super().__init__(width=width, axis=axis)
class Ngram(cde.NgramOp): class Ngram(cde.NgramOp):
""" """
......
...@@ -23,7 +23,7 @@ import mindspore._c_dataengine as cde ...@@ -23,7 +23,7 @@ import mindspore._c_dataengine as cde
from mindspore._c_expression import typing from mindspore._c_expression import typing
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \ from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \
INT32_MAX, check_value, check_positive INT32_MAX, check_value, check_positive, check_pos_int32
def check_unique_list_of_words(words, arg_name): def check_unique_list_of_words(words, arg_name):
...@@ -328,6 +328,17 @@ def check_from_dataset(method): ...@@ -328,6 +328,17 @@ def check_from_dataset(method):
return new_method return new_method
def check_slidingwindow(method):
"""A wrapper that wrap a parameter checker to the original function(sliding window operation)."""
@wraps(method)
def new_method(self, *args, **kwargs):
[width, axis], _ = parse_user_args(method, *args, **kwargs)
check_pos_int32(width, "width")
type_check(axis, (int,), "axis")
return method(self, *args, **kwargs)
return new_method
def check_ngram(method): def check_ngram(method):
"""A wrapper that wraps a parameter checker to the original function.""" """A wrapper that wraps a parameter checker to the original function."""
......
...@@ -92,6 +92,7 @@ SET(DE_UT_SRCS ...@@ -92,6 +92,7 @@ SET(DE_UT_SRCS
perf_data_test.cc perf_data_test.cc
c_api_test.cc c_api_test.cc
tensor_op_fusion_pass_test.cc tensor_op_fusion_pass_test.cc
sliding_window_op_test.cc
) )
add_executable(de_ut_tests ${DE_UT_SRCS}) add_executable(de_ut_tests ${DE_UT_SRCS})
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/text/kernels/sliding_window_op.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
class MindDataTestSlidingWindowOp : public UT::Common {
protected:
MindDataTestSlidingWindowOp() {}
};
TEST_F(MindDataTestSlidingWindowOp, Compute) {
MS_LOG(INFO) << "Doing MindDataTestSlidingWindowOp->Compute.";
std::vector<std::string> strings = {"one", "two", "three", "four", "five", "six", "seven", "eight"};
TensorShape shape({static_cast<dsize_t>(strings.size())});
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape);
std::shared_ptr<Tensor> output;
std::unique_ptr<SlidingWindowOp> op(new SlidingWindowOp(3, 0));
Status s = op->Compute(input, &output);
std::vector<std::string> out = {"one", "two", "three", "two", "three", "four", "three", "four", "five",
"four", "five", "six", "five", "six", "seven", "six", "seven", "eight"};
std::shared_ptr<Tensor> expected = std::make_shared<Tensor>(out, TensorShape({6, 3}));
ASSERT_TRUE(output->shape() == expected->shape());
ASSERT_TRUE(output->type() == expected->type());
MS_LOG(DEBUG) << *output << std::endl;
MS_LOG(DEBUG) << *expected << std::endl;
ASSERT_TRUE(*output == *expected);
MS_LOG(INFO) << "MindDataTestSlidingWindowOp end.";
}
TEST_F(MindDataTestSlidingWindowOp, OutputShape) {
MS_LOG(INFO) << "Doing MindDataTestSlidingWindowOp->OutputShape.";
std::vector<std::string> strings = {"one", "two", "three", "four", "five", "six", "seven", "eight"};
TensorShape shape({static_cast<dsize_t>(strings.size())});
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape);
std::vector<TensorShape> input_shape = {input->shape()};
std::vector<TensorShape> output_shape = {TensorShape({})};
std::unique_ptr<SlidingWindowOp> op(new SlidingWindowOp(3, 0));
Status s = op->OutputShape(input_shape, output_shape);
MS_LOG(DEBUG) << "input_shape" << input_shape[0];
MS_LOG(DEBUG) << "output_shape" << output_shape[0];
ASSERT_TRUE(output_shape[0] == TensorShape({6, 3}));
MS_LOG(INFO) << "MindDataTestSlidingWindowOp end.";
}
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing SlidingWindow in mindspore.dataset
"""
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as text
def test_sliding_window_string():
""" test sliding_window with string type"""
inputs = [["大", "家", "早", "上", "好"]]
expect = np.array([['大', '家'], ['家', '早'], ['早', '上'], ['上', '好']])
dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(2, 0))
result = []
for data in dataset.create_dict_iterator():
for i in range(data['text'].shape[0]):
result.append([])
for j in range(data['text'].shape[1]):
result[i].append(data['text'][i][j].decode('utf8'))
result = np.array(result)
np.testing.assert_array_equal(result, expect)
def test_sliding_window_number():
inputs = [1]
expect = np.array([[1]])
def gen(nums):
yield (np.array(nums),)
dataset = ds.GeneratorDataset(gen(inputs), column_names=["number"])
dataset = dataset.map(input_columns=["number"], operations=text.SlidingWindow(1, -1))
for data in dataset.create_dict_iterator():
np.testing.assert_array_equal(data['number'], expect)
def test_sliding_window_big_width():
inputs = [[1, 2, 3, 4, 5]]
expect = np.array([])
dataset = ds.NumpySlicesDataset(inputs, column_names=["number"], shuffle=False)
dataset = dataset.map(input_columns=["number"], operations=text.SlidingWindow(30, 0))
for data in dataset.create_dict_iterator():
np.testing.assert_array_equal(data['number'], expect)
def test_sliding_window_exception():
try:
_ = text.SlidingWindow(0, 0)
assert False
except ValueError:
pass
try:
_ = text.SlidingWindow("1", 0)
assert False
except TypeError:
pass
try:
_ = text.SlidingWindow(1, "0")
assert False
except TypeError:
pass
try:
inputs = [[1, 2, 3, 4, 5]]
dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(3, -100))
for _ in dataset.create_dict_iterator():
pass
assert False
except RuntimeError as e:
assert "axis supports 0 or -1 only for now." in str(e)
try:
inputs = ["aa", "bb", "cc"]
dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
dataset = dataset.map(input_columns=["text"], operations=text.SlidingWindow(2, 0))
for _ in dataset.create_dict_iterator():
pass
assert False
except RuntimeError as e:
assert "SlidingWindosOp supports 1D Tensors only for now." in str(e)
if __name__ == '__main__':
test_sliding_window_string()
test_sliding_window_number()
test_sliding_window_big_width()
test_sliding_window_exception()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册