提交 3adbde56 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2162 SliceOp

Merge pull request !2162 from h.farahat/slice_op
......@@ -37,8 +37,9 @@
#include "dataset/kernels/image/resize_bilinear_op.h"
#include "dataset/kernels/image/resize_op.h"
#include "dataset/kernels/image/uniform_aug_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/kernels/data/fill_op.h"
#include "dataset/kernels/data/slice_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
......@@ -369,6 +370,37 @@ void bindTensorOps2(py::module *m) {
*m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.")
.def(py::init<std::shared_ptr<Tensor>>());
(void)py::class_<SliceOp, TensorOp, std::shared_ptr<SliceOp>>(*m, "SliceOp", "")
.def(py::init<bool>())
.def(py::init([](const py::list &py_list) {
std::vector<dsize_t> c_list;
for (auto l : py_list) {
if (!l.is_none()) {
c_list.push_back(py::reinterpret_borrow<py::int_>(l));
}
}
return std::make_shared<SliceOp>(c_list);
}))
.def(py::init([](const py::tuple &py_slice) {
if (py_slice.size() != 3) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
}
Slice c_slice;
if (!py_slice[0].is_none() && !py_slice[1].is_none() && !py_slice[2].is_none()) {
c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[0]), py::reinterpret_borrow<py::int_>(py_slice[1]),
py::reinterpret_borrow<py::int_>(py_slice[2]));
} else if (py_slice[0].is_none() && py_slice[2].is_none()) {
c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[1]));
} else if (!py_slice[0].is_none() && !py_slice[1].is_none()) {
c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[0]), py::reinterpret_borrow<py::int_>(py_slice[1]));
}
if (!c_slice.valid()) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
}
return std::make_shared<SliceOp>(c_slice);
}));
(void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>(
*m, "RandomRotationOp",
"Tensor operation to apply RandomRotation."
......
......@@ -916,6 +916,61 @@ Status Tensor::CopyLastDimAt(const std::shared_ptr<Tensor> &src, const std::vect
CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(dst_addr, len, src_addr, len) == 0, "memcpy error");
return Status::OK();
}
Status Tensor::Slice(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices) {
CHECK_FAIL_RETURN_UNEXPECTED(shape_.Rank() == 1, "Currently Slice work with rank 1 tensors only.");
CHECK_FAIL_RETURN_UNEXPECTED(!indices.empty(), "Indices are empty, generated tensor would be empty.");
if (type_.IsNumeric()) {
return SliceNumeric(out, indices);
} else {
return SliceString(out, indices);
}
}
Status Tensor::SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices) {
RETURN_IF_NOT_OK(
CreateTensor(out, TensorImpl::kFlexible, TensorShape({static_cast<dsize_t>(indices.size())}), type_));
(*out)->GetMutableBuffer();
dsize_t out_index = 0;
dsize_t dim_length = shape_[0];
dsize_t type_size = type_.SizeInBytes();
dsize_t src_start = handleNeg(indices[0], dim_length);
uchar *dst_addr = (*out)->data_;
dsize_t count = 1;
for (dsize_t i = 0; i < indices.size(); i++) {
dsize_t cur_index = handleNeg(indices[i], dim_length);
CHECK_FAIL_RETURN_UNEXPECTED(
cur_index >= 0 && cur_index < dim_length,
"Index " + std::to_string(indices[i]) + " is out of bounds [0," + std::to_string(dim_length) + ")");
if (i < indices.size() - 1) {
dsize_t next_index = handleNeg(indices[i + 1], dim_length);
if (next_index == cur_index + 1) {
count++;
continue;
}
}
memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size, count * type_size);
out_index += count;
if (i < indices.size() - 1) {
src_start = handleNeg(indices[i + 1], dim_length); // next index
}
count = 1;
}
return Status::OK();
}
Status Tensor::SliceString(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices) {
dsize_t dim_length = shape_[0];
std::vector<std::string> strings;
for (dsize_t index : indices) {
dsize_t cur_index = handleNeg(index, dim_length);
CHECK_FAIL_RETURN_UNEXPECTED(
cur_index >= 0 && cur_index < dim_length,
"Index " + std::to_string(index) + " is out of bounds [0," + std::to_string(dim_length) + ")");
std::string_view sv;
GetItemAt(&sv, {cur_index});
strings.emplace_back(sv);
}
return CreateTensor(out, strings);
}
} // namespace dataset
} // namespace mindspore
......@@ -347,6 +347,22 @@ class Tensor {
return ss.str();
}
// Handle negative indices.
static inline dsize_t handleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; }
// Slice tensor bases on the given indicies. Copy the sliced data into out tensor. Only rank1 tensors are supported.
// Based on the type of tensor, SliceNumeric or SliceString will be called
// @param out Tensor
// @param indices vector of indices
// @return Status error code
Status Slice(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices);
// Slice numeric tensors.
Status SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices);
// Slice string tensors
Status SliceString(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices);
// Constructs numpy array from input tensor
// @param data this data is the location of python data
// @return Status code
......
......@@ -5,4 +5,5 @@ add_library(kernels-data OBJECT
one_hot_op.cc
type_cast_op.cc
to_float16_op.cc
fill_op.cc)
fill_op.cc
slice_op.cc)
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/kernels/data/slice_op.h"
#include "dataset/core/tensor.h"
#include "dataset/kernels/data/data_utils.h"
#include "dataset/kernels/tensor_op.h"
namespace mindspore {
namespace dataset {
Status SliceOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1, "SliceOp supports 1D Tensors only for now.");
// if `all` flag is true, output is just the input.
if (all_) {
*output = input;
return Status::OK();
}
// if slice object was provided, indices should be empty. Generate indices from the slice object.
if (slice_.valid() && indices_.empty()) {
dsize_t len = input->shape()[0];
indices_ = slice_.Indices(len);
return input->Slice(output, indices_);
}
// if indices are not empty, slices should be invalid, use indices_ to slice
if (!indices_.empty() && !slice_.valid()) {
return input->Slice(output, indices_);
}
RETURN_STATUS_UNEXPECTED("The indexing parameters are invalid");
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_DATA_SLICE_OP_H_
#define DATASET_KERNELS_DATA_SLICE_OP_H_
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
namespace mindspore {
namespace dataset {
class Slice {
public:
Slice() : start_(0), stop_(0), step_(0) {}
Slice(dsize_t start, dsize_t stop, dsize_t step) : start_(start), stop_(stop), step_(step) {}
Slice(dsize_t start, dsize_t stop) : start_(start), stop_(stop), step_(1) {}
explicit Slice(dsize_t stop) : start_(0), stop_(stop), step_(1) {}
std::vector<dsize_t> Indices(dsize_t length) {
std::vector<dsize_t> indices;
dsize_t index = std::min(Tensor::handleNeg(start_, length), length);
dsize_t end_index = std::min(Tensor::handleNeg(stop_, length), length);
if (step_ > 0) {
for (; index < end_index; index += step_) {
indices.push_back(index);
}
} else {
for (; index > end_index; index += step_) {
indices.push_back(index);
}
}
return indices;
}
bool valid() { return !(start_ == 0 && stop_ == 0 && step_ == 0); }
dsize_t start_;
dsize_t stop_;
dsize_t step_;
};
class SliceOp : public TensorOp {
public:
explicit SliceOp(std::vector<dsize_t> indices) : indices_(std::move(indices)) {}
explicit SliceOp(Slice slice) : slice_(slice) {}
explicit SliceOp(bool all) : all_(all) {}
~SliceOp() override = default;
void Print(std::ostream &out) const override { out << "SliceOp"; }
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
private:
// only on of the following will be valid
// given indices to slice the Tensor. Empty vector if invalid.
std::vector<dsize_t> indices_;
// Slice object. All start, stop and step are 0 if invalid.
Slice slice_;
// Flag to read all indcies in the dim.
bool all_ = false;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_KERNELS_DATA_ONE_HOT_OP_H_
......@@ -17,7 +17,8 @@ This module c_transforms provides common operations, including OneHotOp and Type
"""
import numpy as np
import mindspore._c_dataengine as cde
from .validators import check_num_classes, check_de_type, check_fill_value
from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op
from ..core.datatypes import mstype_to_detype
......@@ -64,3 +65,46 @@ class TypeCast(cde.TypeCastOp):
data_type = mstype_to_detype(data_type)
self.data_type = str(data_type)
super().__init__(data_type)
class Slice(cde.SliceOp):
"""
Slice operation to extract a tensor out using the given n slices.
The functionality of Slice is similar to NumPy indexing feature.
(Currently only rank 1 Tensors are supported)
Args:
*slices: Maximum n number of objects to slice a tensor of rank n.
One object in slices can be one of:
1. int: slice this index only. Negative index is supported.
2. slice object: slice the generated indices from the slice object. Similar to `start:stop:step`.
3. None: slice the whole dimension. Similar to `:` in python indexing.
4. Ellipses ...: slice all dimensions between the two slices.
Examples:
>>> # Data before
>>> # | col |
>>> # +---------+
>>> # | [1,2,3] |
>>> # +---------|
>>> data = data.map(operations=Slice(slice(1,3))) # slice indices 1 and 2 only
>>> # Data after
>>> # | col |
>>> # +------------+
>>> # | [1,2] |
>>> # +------------|
"""
@check_slice_op
def __init__(self, *slices):
dim0 = slices[0]
if isinstance(dim0, int):
dim0 = [dim0]
elif dim0 is None:
dim0 = True
elif isinstance(dim0, slice):
dim0 = (dim0.start, dim0.stop, dim0.step)
elif dim0 is Ellipsis:
dim0 = True
super().__init__(dim0)
......@@ -15,6 +15,7 @@
"""Validators for TensorOps.
"""
from functools import wraps
from mindspore._c_expression import typing
# POS_INT_MIN is used to limit values from starting from 0
......@@ -195,3 +196,20 @@ def check_de_type(method):
return method(self, **kwargs)
return new_method
def check_slice_op(method):
"""Wrapper method to check the parameters of slice."""
@wraps(method)
def new_method(self, *args):
for i, arg in enumerate(args):
if arg is not None and arg is not Ellipsis and not isinstance(arg, (int, slice, list)):
raise TypeError("Indexing of dim " + str(i) + "is not of valid type")
if isinstance(arg, list):
for a in arg:
if not isinstance(a, int):
raise TypeError("Index " + a + " is not an int")
return method(self, *args)
return new_method
......@@ -28,17 +28,13 @@ using namespace mindspore::dataset;
namespace py = pybind11;
class MindDataTestTensorDE : public UT::Common {
public:
MindDataTestTensorDE() {}
void SetUp() {
GlobalInit();
}
void SetUp() { GlobalInit(); }
};
TEST_F(MindDataTestTensorDE, Basics) {
std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2, 3}), DataType(DataType::DE_UINT64));
ASSERT_TRUE((t->AllocateBuffer(t->SizeInBytes())).IsOk());
......@@ -167,8 +163,7 @@ TEST_F(MindDataTestTensorDE, InsertTensor) {
// Test the bug of Tensor::ToString will exec failed for Tensor which store bool values
TEST_F(MindDataTestTensorDE, BoolTensor) {
std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2}),
DataType(DataType::DE_BOOL));
std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2}), DataType(DataType::DE_BOOL));
t->SetItemAt<bool>({0}, true);
t->SetItemAt<bool>({1}, true);
std::string out = t->ToString();
......@@ -255,14 +250,19 @@ void checkCvMat(TensorShape shape, DataType type) {
} else {
ASSERT_EQ(m.size[0], shape[0]);
}
if (shape.Rank() == 3) { ASSERT_EQ(m.channels(), shape[2]); }
if (shape.Rank() == 3) {
ASSERT_EQ(m.channels(), shape[2]);
}
ASSERT_EQ(m.dims, 2);
ASSERT_EQ(m.size.dims(), 2);
if (shape.Rank() > 0) { ASSERT_EQ(m.rows, shape[0]); }
if (shape.Rank() > 1) { ASSERT_EQ(m.cols, shape[1]); }
if (shape.Rank() > 0) {
ASSERT_EQ(m.rows, shape[0]);
}
if (shape.Rank() > 1) {
ASSERT_EQ(m.cols, shape[1]);
}
} else {
for (dsize_t i = 0; i < shape.Rank(); i++)
ASSERT_EQ(m.size[static_cast<int>(i)], shape[i]);
for (dsize_t i = 0; i < shape.Rank(); i++) ASSERT_EQ(m.size[static_cast<int>(i)], shape[i]);
ASSERT_EQ(m.dims, shape.Rank());
ASSERT_EQ(m.size.dims(), shape.Rank());
ASSERT_EQ(m.rows, -1);
......@@ -394,3 +394,16 @@ TEST_F(MindDataTestTensorDE, TensorIterator) {
}
ASSERT_TRUE(ctr == 6);
}
TEST_F(MindDataTestTensorDE, TensorSlice) {
std::shared_ptr<Tensor> t;
Tensor::CreateTensor(&t, std::vector<dsize_t>{0, 1, 2, 3, 4});
std::shared_ptr<Tensor> t2;
auto x = std::vector<dsize_t>{0, 3, 4};
std::shared_ptr<Tensor> expected;
Tensor::CreateTensor(&expected, x);
t->Slice(&t2, x);
ASSERT_EQ(*t2, *expected);
t->Slice(&t2, std::vector<dsize_t>{0, 1, 2, 3, 4});
ASSERT_EQ(*t2, *t);
}
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing TypeCast op in DE
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as ops
def slice_compare(array, indexing):
data = ds.NumpySlicesDataset([array])
array = np.array(array)
data = data.map(operations=ops.Slice(indexing))
for d in data:
if indexing is None:
array = array[:]
else:
array = array[indexing]
np.testing.assert_array_equal(array, d[0])
def test_slice_all():
slice_compare([1, 2, 3, 4, 5], None)
slice_compare([1, 2, 3, 4, 5], ...)
def test_slice_single_index():
slice_compare([1, 2, 3, 4, 5], 0)
slice_compare([1, 2, 3, 4, 5], 4)
slice_compare([1, 2, 3, 4, 5], 2)
slice_compare([1, 2, 3, 4, 5], -1)
slice_compare([1, 2, 3, 4, 5], -5)
slice_compare([1, 2, 3, 4, 5], -3)
def test_slice_list_index():
slice_compare([1, 2, 3, 4, 5], [0, 1, 4])
slice_compare([1, 2, 3, 4, 5], [4, 1, 0])
slice_compare([1, 2, 3, 4, 5], [-1, 1, 0])
slice_compare([1, 2, 3, 4, 5], [-1, -4, -2])
slice_compare([1, 2, 3, 4, 5], [3, 3, 3])
slice_compare([1, 2, 3, 4, 5], [1, 1, 1, 1, 1])
def test_slice_slice_obj_2s():
slice_compare([1, 2, 3, 4, 5], slice(0, 2))
slice_compare([1, 2, 3, 4, 5], slice(2, 4))
slice_compare([1, 2, 3, 4, 5], slice(4, 10))
def test_slice_slice_obj_1s():
slice_compare([1, 2, 3, 4, 5], slice(1))
slice_compare([1, 2, 3, 4, 5], slice(4))
slice_compare([1, 2, 3, 4, 5], slice(10))
def test_slice_slice_obj_3s():
slice_compare([1, 2, 3, 4, 5], slice(0, 2, 1))
slice_compare([1, 2, 3, 4, 5], slice(0, 4, 1))
slice_compare([1, 2, 3, 4, 5], slice(0, 10, 1))
slice_compare([1, 2, 3, 4, 5], slice(0, 5, 2))
slice_compare([1, 2, 3, 4, 5], slice(0, 2, 2))
slice_compare([1, 2, 3, 4, 5], slice(0, 1, 2))
slice_compare([1, 2, 3, 4, 5], slice(4, 5, 1))
slice_compare([1, 2, 3, 4, 5], slice(2, 5, 3))
def test_slice_slice_obj_3s_double():
slice_compare([1., 2., 3., 4., 5.], slice(0, 2, 1))
slice_compare([1., 2., 3., 4., 5.], slice(0, 4, 1))
slice_compare([1., 2., 3., 4., 5.], slice(0, 10, 1))
slice_compare([1., 2., 3., 4., 5.], slice(0, 5, 2))
slice_compare([1., 2., 3., 4., 5.], slice(0, 2, 2))
slice_compare([1., 2., 3., 4., 5.], slice(0, 1, 2))
slice_compare([1., 2., 3., 4., 5.], slice(4, 5, 1))
slice_compare([1., 2., 3., 4., 5.], slice(2, 5, 3))
def test_slice_slice_obj_neg():
slice_compare([1, 2, 3, 4, 5], slice(-1, -5, -1))
slice_compare([1, 2, 3, 4, 5], slice(-1))
slice_compare([1, 2, 3, 4, 5], slice(-2))
slice_compare([1, 2, 3, 4, 5], slice(-1, -5, -2))
slice_compare([1, 2, 3, 4, 5], slice(-5, -1, 2))
slice_compare([1, 2, 3, 4, 5], slice(-5, -1))
def test_slice_exceptions():
with pytest.raises(RuntimeError) as info:
slice_compare([1, 2, 3, 4, 5], 5)
assert "Index 5 is out of bounds [0,5)" in str(info.value)
with pytest.raises(RuntimeError) as info:
slice_compare([1, 2, 3, 4, 5], slice(0))
assert "Indices are empty, generated tensor would be empty." in str(info.value)
with pytest.raises(RuntimeError) as info:
slice_compare([1, 2, 3, 4, 5], slice(5, 10, 1))
assert "Indices are empty, generated tensor would be empty." in str(info.value)
with pytest.raises(RuntimeError) as info:
slice_compare([1, 2, 3, 4, 5], slice(-1, -5, 1))
assert "Indices are empty, generated tensor would be empty." in str(info.value)
def test_slice_all_str():
slice_compare([b"1", b"2", b"3", b"4", b"5"], None)
slice_compare([b"1", b"2", b"3", b"4", b"5"], ...)
def test_slice_single_index_str():
slice_compare([b"1", b"2", b"3", b"4", b"5"], 0)
slice_compare([b"1", b"2", b"3", b"4", b"5"], 4)
slice_compare([b"1", b"2", b"3", b"4", b"5"], 2)
slice_compare([b"1", b"2", b"3", b"4", b"5"], -1)
slice_compare([b"1", b"2", b"3", b"4", b"5"], -5)
slice_compare([b"1", b"2", b"3", b"4", b"5"], -3)
def test_slice_list_index_str():
slice_compare([b"1", b"2", b"3", b"4", b"5"], [0, 1, 4])
slice_compare([b"1", b"2", b"3", b"4", b"5"], [4, 1, 0])
slice_compare([b"1", b"2", b"3", b"4", b"5"], [-1, 1, 0])
slice_compare([b"1", b"2", b"3", b"4", b"5"], [-1, -4, -2])
slice_compare([b"1", b"2", b"3", b"4", b"5"], [3, 3, 3])
slice_compare([b"1", b"2", b"3", b"4", b"5"], [1, 1, 1, 1, 1])
def test_slice_slice_obj_2s_str():
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 2))
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(2, 4))
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(4, 10))
def test_slice_slice_obj_1s_str():
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(1))
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(4))
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(10))
def test_slice_slice_obj_3s_str():
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 2, 1))
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 4, 1))
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 10, 1))
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 5, 2))
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 2, 2))
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 1, 2))
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(4, 5, 1))
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(2, 5, 3))
def test_slice_slice_obj_neg_str():
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1, -5, -1))
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1))
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-2))
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1, -5, -2))
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-5, -1, 2))
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-5, -1))
def test_slice_exceptions_str():
with pytest.raises(RuntimeError) as info:
slice_compare([b"1", b"2", b"3", b"4", b"5"], 5)
assert "Index 5 is out of bounds [0,5)" in str(info.value)
with pytest.raises(RuntimeError) as info:
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0))
assert "Indices are empty, generated tensor would be empty." in str(info.value)
with pytest.raises(RuntimeError) as info:
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(5, 10, 1))
assert "Indices are empty, generated tensor would be empty." in str(info.value)
with pytest.raises(RuntimeError) as info:
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1, -5, 1))
assert "Indices are empty, generated tensor would be empty." in str(info.value)
if __name__ == "__main__":
test_slice_all()
test_slice_single_index()
test_slice_list_index()
test_slice_slice_obj_3s()
test_slice_slice_obj_2s()
test_slice_slice_obj_1s()
test_slice_slice_obj_neg()
test_slice_exceptions()
test_slice_slice_obj_3s_double()
test_slice_all_str()
test_slice_single_index_str()
test_slice_list_index_str()
test_slice_slice_obj_3s_str()
test_slice_slice_obj_2s_str()
test_slice_slice_obj_1s_str()
test_slice_slice_obj_neg_str()
test_slice_exceptions_str()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册