Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
e2012a1d
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
e2012a1d
编写于
6月 13, 2020
作者:
H
hesham
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Slice Op
上级
cc80dca7
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
537 addition
and
16 deletion
+537
-16
mindspore/ccsrc/dataset/api/python_bindings.cc
mindspore/ccsrc/dataset/api/python_bindings.cc
+33
-1
mindspore/ccsrc/dataset/core/tensor.cc
mindspore/ccsrc/dataset/core/tensor.cc
+55
-0
mindspore/ccsrc/dataset/core/tensor.h
mindspore/ccsrc/dataset/core/tensor.h
+16
-0
mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
+2
-1
mindspore/ccsrc/dataset/kernels/data/slice_op.cc
mindspore/ccsrc/dataset/kernels/data/slice_op.cc
+48
-0
mindspore/ccsrc/dataset/kernels/data/slice_op.h
mindspore/ccsrc/dataset/kernels/data/slice_op.h
+83
-0
mindspore/dataset/transforms/c_transforms.py
mindspore/dataset/transforms/c_transforms.py
+45
-1
mindspore/dataset/transforms/validators.py
mindspore/dataset/transforms/validators.py
+18
-0
tests/ut/cpp/dataset/tensor_test.cc
tests/ut/cpp/dataset/tensor_test.cc
+26
-13
tests/ut/python/dataset/test_slice_op.py
tests/ut/python/dataset/test_slice_op.py
+211
-0
未找到文件。
mindspore/ccsrc/dataset/api/python_bindings.cc
浏览文件 @
e2012a1d
...
...
@@ -37,8 +37,9 @@
#include "dataset/kernels/image/resize_bilinear_op.h"
#include "dataset/kernels/image/resize_op.h"
#include "dataset/kernels/image/uniform_aug_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/kernels/data/fill_op.h"
#include "dataset/kernels/data/slice_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
...
...
@@ -368,6 +369,37 @@ void bindTensorOps2(py::module *m) {
*
m
,
"FillOp"
,
"Tensor operation to return tensor filled with same value as input fill value."
)
.
def
(
py
::
init
<
std
::
shared_ptr
<
Tensor
>>
());
(
void
)
py
::
class_
<
SliceOp
,
TensorOp
,
std
::
shared_ptr
<
SliceOp
>>
(
*
m
,
"SliceOp"
,
""
)
.
def
(
py
::
init
<
bool
>
())
.
def
(
py
::
init
([](
const
py
::
list
&
py_list
)
{
std
::
vector
<
dsize_t
>
c_list
;
for
(
auto
l
:
py_list
)
{
if
(
!
l
.
is_none
())
{
c_list
.
push_back
(
py
::
reinterpret_borrow
<
py
::
int_
>
(
l
));
}
}
return
std
::
make_shared
<
SliceOp
>
(
c_list
);
}))
.
def
(
py
::
init
([](
const
py
::
tuple
&
py_slice
)
{
if
(
py_slice
.
size
()
!=
3
)
{
THROW_IF_ERROR
(
Status
(
StatusCode
::
kUnexpectedError
,
__LINE__
,
__FILE__
,
"Wrong slice object"
));
}
Slice
c_slice
;
if
(
!
py_slice
[
0
].
is_none
()
&&
!
py_slice
[
1
].
is_none
()
&&
!
py_slice
[
2
].
is_none
())
{
c_slice
=
Slice
(
py
::
reinterpret_borrow
<
py
::
int_
>
(
py_slice
[
0
]),
py
::
reinterpret_borrow
<
py
::
int_
>
(
py_slice
[
1
]),
py
::
reinterpret_borrow
<
py
::
int_
>
(
py_slice
[
2
]));
}
else
if
(
py_slice
[
0
].
is_none
()
&&
py_slice
[
2
].
is_none
())
{
c_slice
=
Slice
(
py
::
reinterpret_borrow
<
py
::
int_
>
(
py_slice
[
1
]));
}
else
if
(
!
py_slice
[
0
].
is_none
()
&&
!
py_slice
[
1
].
is_none
())
{
c_slice
=
Slice
(
py
::
reinterpret_borrow
<
py
::
int_
>
(
py_slice
[
0
]),
py
::
reinterpret_borrow
<
py
::
int_
>
(
py_slice
[
1
]));
}
if
(
!
c_slice
.
valid
())
{
THROW_IF_ERROR
(
Status
(
StatusCode
::
kUnexpectedError
,
__LINE__
,
__FILE__
,
"Wrong slice object"
));
}
return
std
::
make_shared
<
SliceOp
>
(
c_slice
);
}));
(
void
)
py
::
class_
<
RandomRotationOp
,
TensorOp
,
std
::
shared_ptr
<
RandomRotationOp
>>
(
*
m
,
"RandomRotationOp"
,
"Tensor operation to apply RandomRotation."
...
...
mindspore/ccsrc/dataset/core/tensor.cc
浏览文件 @
e2012a1d
...
...
@@ -916,6 +916,61 @@ Status Tensor::CopyLastDimAt(const std::shared_ptr<Tensor> &src, const std::vect
CHECK_FAIL_RETURN_UNEXPECTED
(
memcpy_s
(
dst_addr
,
len
,
src_addr
,
len
)
==
0
,
"memcpy error"
);
return
Status
::
OK
();
}
Status
Tensor
::
Slice
(
std
::
shared_ptr
<
Tensor
>
*
out
,
const
std
::
vector
<
dsize_t
>
&
indices
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
shape_
.
Rank
()
==
1
,
"Currently Slice work with rank 1 tensors only."
);
CHECK_FAIL_RETURN_UNEXPECTED
(
!
indices
.
empty
(),
"Indices are empty, generated tensor would be empty."
);
if
(
type_
.
IsNumeric
())
{
return
SliceNumeric
(
out
,
indices
);
}
else
{
return
SliceString
(
out
,
indices
);
}
}
Status
Tensor
::
SliceNumeric
(
std
::
shared_ptr
<
Tensor
>
*
out
,
const
std
::
vector
<
dsize_t
>
&
indices
)
{
RETURN_IF_NOT_OK
(
CreateTensor
(
out
,
TensorImpl
::
kFlexible
,
TensorShape
({
static_cast
<
dsize_t
>
(
indices
.
size
())}),
type_
));
(
*
out
)
->
GetMutableBuffer
();
dsize_t
out_index
=
0
;
dsize_t
dim_length
=
shape_
[
0
];
dsize_t
type_size
=
type_
.
SizeInBytes
();
dsize_t
src_start
=
handleNeg
(
indices
[
0
],
dim_length
);
uchar
*
dst_addr
=
(
*
out
)
->
data_
;
dsize_t
count
=
1
;
for
(
dsize_t
i
=
0
;
i
<
indices
.
size
();
i
++
)
{
dsize_t
cur_index
=
handleNeg
(
indices
[
i
],
dim_length
);
CHECK_FAIL_RETURN_UNEXPECTED
(
cur_index
>=
0
&&
cur_index
<
dim_length
,
"Index "
+
std
::
to_string
(
indices
[
i
])
+
" is out of bounds [0,"
+
std
::
to_string
(
dim_length
)
+
")"
);
if
(
i
<
indices
.
size
()
-
1
)
{
dsize_t
next_index
=
handleNeg
(
indices
[
i
+
1
],
dim_length
);
if
(
next_index
==
cur_index
+
1
)
{
count
++
;
continue
;
}
}
memcpy_s
(
dst_addr
+
out_index
*
type_size
,
(
*
out
)
->
SizeInBytes
(),
data_
+
src_start
*
type_size
,
count
*
type_size
);
out_index
+=
count
;
if
(
i
<
indices
.
size
()
-
1
)
{
src_start
=
handleNeg
(
indices
[
i
+
1
],
dim_length
);
// next index
}
count
=
1
;
}
return
Status
::
OK
();
}
Status
Tensor
::
SliceString
(
std
::
shared_ptr
<
Tensor
>
*
out
,
const
std
::
vector
<
dsize_t
>
&
indices
)
{
dsize_t
dim_length
=
shape_
[
0
];
std
::
vector
<
std
::
string
>
strings
;
for
(
dsize_t
index
:
indices
)
{
dsize_t
cur_index
=
handleNeg
(
index
,
dim_length
);
CHECK_FAIL_RETURN_UNEXPECTED
(
cur_index
>=
0
&&
cur_index
<
dim_length
,
"Index "
+
std
::
to_string
(
index
)
+
" is out of bounds [0,"
+
std
::
to_string
(
dim_length
)
+
")"
);
std
::
string_view
sv
;
GetItemAt
(
&
sv
,
{
cur_index
});
strings
.
emplace_back
(
sv
);
}
return
CreateTensor
(
out
,
strings
);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/core/tensor.h
浏览文件 @
e2012a1d
...
...
@@ -347,6 +347,22 @@ class Tensor {
return
ss
.
str
();
}
// Handle negative indices.
static
inline
dsize_t
handleNeg
(
dsize_t
index
,
dsize_t
length
)
{
return
(
index
<
0
)
?
(
index
+
length
)
:
index
;
}
// Slice tensor bases on the given indicies. Copy the sliced data into out tensor. Only rank1 tensors are supported.
// Based on the type of tensor, SliceNumeric or SliceString will be called
// @param out Tensor
// @param indices vector of indices
// @return Status error code
Status
Slice
(
std
::
shared_ptr
<
Tensor
>
*
out
,
const
std
::
vector
<
dsize_t
>
&
indices
);
// Slice numeric tensors.
Status
SliceNumeric
(
std
::
shared_ptr
<
Tensor
>
*
out
,
const
std
::
vector
<
dsize_t
>
&
indices
);
// Slice string tensors
Status
SliceString
(
std
::
shared_ptr
<
Tensor
>
*
out
,
const
std
::
vector
<
dsize_t
>
&
indices
);
// Constructs numpy array from input tensor
// @param data this data is the location of python data
// @return Status code
...
...
mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
浏览文件 @
e2012a1d
...
...
@@ -5,4 +5,5 @@ add_library(kernels-data OBJECT
one_hot_op.cc
type_cast_op.cc
to_float16_op.cc
fill_op.cc
)
fill_op.cc
slice_op.cc
)
mindspore/ccsrc/dataset/kernels/data/slice_op.cc
0 → 100644
浏览文件 @
e2012a1d
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/kernels/data/slice_op.h"
#include "dataset/core/tensor.h"
#include "dataset/kernels/data/data_utils.h"
#include "dataset/kernels/tensor_op.h"
namespace
mindspore
{
namespace
dataset
{
Status
SliceOp
::
Compute
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
std
::
shared_ptr
<
Tensor
>
*
output
)
{
IO_CHECK
(
input
,
output
);
CHECK_FAIL_RETURN_UNEXPECTED
(
input
->
shape
().
Rank
()
==
1
,
"SliceOp supports 1D Tensors only for now."
);
// if `all` flag is true, output is just the input.
if
(
all_
)
{
*
output
=
input
;
return
Status
::
OK
();
}
// if slice object was provided, indices should be empty. Generate indices from the slice object.
if
(
slice_
.
valid
()
&&
indices_
.
empty
())
{
dsize_t
len
=
input
->
shape
()[
0
];
indices_
=
slice_
.
Indices
(
len
);
return
input
->
Slice
(
output
,
indices_
);
}
// if indices are not empty, slices should be invalid, use indices_ to slice
if
(
!
indices_
.
empty
()
&&
!
slice_
.
valid
())
{
return
input
->
Slice
(
output
,
indices_
);
}
RETURN_STATUS_UNEXPECTED
(
"The indexing parameters are invalid"
);
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/kernels/data/slice_op.h
0 → 100644
浏览文件 @
e2012a1d
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_DATA_SLICE_OP_H_
#define DATASET_KERNELS_DATA_SLICE_OP_H_
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
namespace
mindspore
{
namespace
dataset
{
class
Slice
{
public:
Slice
()
:
start_
(
0
),
stop_
(
0
),
step_
(
0
)
{}
Slice
(
dsize_t
start
,
dsize_t
stop
,
dsize_t
step
)
:
start_
(
start
),
stop_
(
stop
),
step_
(
step
)
{}
Slice
(
dsize_t
start
,
dsize_t
stop
)
:
start_
(
start
),
stop_
(
stop
),
step_
(
1
)
{}
explicit
Slice
(
dsize_t
stop
)
:
start_
(
0
),
stop_
(
stop
),
step_
(
1
)
{}
std
::
vector
<
dsize_t
>
Indices
(
dsize_t
length
)
{
std
::
vector
<
dsize_t
>
indices
;
dsize_t
index
=
std
::
min
(
Tensor
::
handleNeg
(
start_
,
length
),
length
);
dsize_t
end_index
=
std
::
min
(
Tensor
::
handleNeg
(
stop_
,
length
),
length
);
if
(
step_
>
0
)
{
for
(;
index
<
end_index
;
index
+=
step_
)
{
indices
.
push_back
(
index
);
}
}
else
{
for
(;
index
>
end_index
;
index
+=
step_
)
{
indices
.
push_back
(
index
);
}
}
return
indices
;
}
bool
valid
()
{
return
!
(
start_
==
0
&&
stop_
==
0
&&
step_
==
0
);
}
dsize_t
start_
;
dsize_t
stop_
;
dsize_t
step_
;
};
class
SliceOp
:
public
TensorOp
{
public:
explicit
SliceOp
(
std
::
vector
<
dsize_t
>
indices
)
:
indices_
(
std
::
move
(
indices
))
{}
explicit
SliceOp
(
Slice
slice
)
:
slice_
(
slice
)
{}
explicit
SliceOp
(
bool
all
)
:
all_
(
all
)
{}
~
SliceOp
()
override
=
default
;
void
Print
(
std
::
ostream
&
out
)
const
override
{
out
<<
"SliceOp"
;
}
Status
Compute
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
std
::
shared_ptr
<
Tensor
>
*
output
)
override
;
private:
// only on of the following will be valid
// given indices to slice the Tensor. Empty vector if invalid.
std
::
vector
<
dsize_t
>
indices_
;
// Slice object. All start, stop and step are 0 if invalid.
Slice
slice_
;
// Flag to read all indcies in the dim.
bool
all_
=
false
;
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_KERNELS_DATA_ONE_HOT_OP_H_
mindspore/dataset/transforms/c_transforms.py
浏览文件 @
e2012a1d
...
...
@@ -17,7 +17,8 @@ This module c_transforms provides common operations, including OneHotOp and Type
"""
import
numpy
as
np
import
mindspore._c_dataengine
as
cde
from
.validators
import
check_num_classes
,
check_de_type
,
check_fill_value
from
.validators
import
check_num_classes
,
check_de_type
,
check_fill_value
,
check_slice_op
from
..core.datatypes
import
mstype_to_detype
...
...
@@ -64,3 +65,46 @@ class TypeCast(cde.TypeCastOp):
data_type
=
mstype_to_detype
(
data_type
)
self
.
data_type
=
str
(
data_type
)
super
().
__init__
(
data_type
)
class
Slice
(
cde
.
SliceOp
):
"""
Slice operation to extract a tensor out using the given n slices.
The functionality of Slice is similar to NumPy indexing feature.
(Currently only rank 1 Tensors are supported)
Args:
*slices: Maximum n number of objects to slice a tensor of rank n.
One object in slices can be one of:
1. int: slice this index only. Negative index is supported.
2. slice object: slice the generated indices from the slice object. Similar to `start:stop:step`.
3. None: slice the whole dimension. Similar to `:` in python indexing.
4. Ellipses ...: slice all dimensions between the two slices.
Examples:
>>> # Data before
>>> # | col |
>>> # +---------+
>>> # | [1,2,3] |
>>> # +---------|
>>> data = data.map(operations=Slice(slice(1,3))) # slice indices 1 and 2 only
>>> # Data after
>>> # | col |
>>> # +------------+
>>> # | [1,2] |
>>> # +------------|
"""
@
check_slice_op
def
__init__
(
self
,
*
slices
):
dim0
=
slices
[
0
]
if
isinstance
(
dim0
,
int
):
dim0
=
[
dim0
]
elif
dim0
is
None
:
dim0
=
True
elif
isinstance
(
dim0
,
slice
):
dim0
=
(
dim0
.
start
,
dim0
.
stop
,
dim0
.
step
)
elif
dim0
is
Ellipsis
:
dim0
=
True
super
().
__init__
(
dim0
)
mindspore/dataset/transforms/validators.py
浏览文件 @
e2012a1d
...
...
@@ -15,6 +15,7 @@
"""Validators for TensorOps.
"""
from
functools
import
wraps
from
mindspore._c_expression
import
typing
# POS_INT_MIN is used to limit values from starting from 0
...
...
@@ -195,3 +196,20 @@ def check_de_type(method):
return
method
(
self
,
**
kwargs
)
return
new_method
def
check_slice_op
(
method
):
"""Wrapper method to check the parameters of slice."""
@
wraps
(
method
)
def
new_method
(
self
,
*
args
):
for
i
,
arg
in
enumerate
(
args
):
if
arg
is
not
None
and
arg
is
not
Ellipsis
and
not
isinstance
(
arg
,
(
int
,
slice
,
list
)):
raise
TypeError
(
"Indexing of dim "
+
str
(
i
)
+
"is not of valid type"
)
if
isinstance
(
arg
,
list
):
for
a
in
arg
:
if
not
isinstance
(
a
,
int
):
raise
TypeError
(
"Index "
+
a
+
" is not an int"
)
return
method
(
self
,
*
args
)
return
new_method
tests/ut/cpp/dataset/tensor_test.cc
浏览文件 @
e2012a1d
...
...
@@ -28,17 +28,13 @@ using namespace mindspore::dataset;
namespace
py
=
pybind11
;
class
MindDataTestTensorDE
:
public
UT
::
Common
{
public:
MindDataTestTensorDE
()
{}
MindDataTestTensorDE
()
{}
void
SetUp
()
{
GlobalInit
();
}
void
SetUp
()
{
GlobalInit
();
}
};
TEST_F
(
MindDataTestTensorDE
,
Basics
)
{
std
::
shared_ptr
<
Tensor
>
t
=
std
::
make_shared
<
Tensor
>
(
TensorShape
({
2
,
3
}),
DataType
(
DataType
::
DE_UINT64
));
ASSERT_TRUE
((
t
->
AllocateBuffer
(
t
->
SizeInBytes
())).
IsOk
());
...
...
@@ -167,8 +163,7 @@ TEST_F(MindDataTestTensorDE, InsertTensor) {
// Test the bug of Tensor::ToString will exec failed for Tensor which store bool values
TEST_F
(
MindDataTestTensorDE
,
BoolTensor
)
{
std
::
shared_ptr
<
Tensor
>
t
=
std
::
make_shared
<
Tensor
>
(
TensorShape
({
2
}),
DataType
(
DataType
::
DE_BOOL
));
std
::
shared_ptr
<
Tensor
>
t
=
std
::
make_shared
<
Tensor
>
(
TensorShape
({
2
}),
DataType
(
DataType
::
DE_BOOL
));
t
->
SetItemAt
<
bool
>
({
0
},
true
);
t
->
SetItemAt
<
bool
>
({
1
},
true
);
std
::
string
out
=
t
->
ToString
();
...
...
@@ -255,14 +250,19 @@ void checkCvMat(TensorShape shape, DataType type) {
}
else
{
ASSERT_EQ
(
m
.
size
[
0
],
shape
[
0
]);
}
if
(
shape
.
Rank
()
==
3
)
{
ASSERT_EQ
(
m
.
channels
(),
shape
[
2
]);
}
if
(
shape
.
Rank
()
==
3
)
{
ASSERT_EQ
(
m
.
channels
(),
shape
[
2
]);
}
ASSERT_EQ
(
m
.
dims
,
2
);
ASSERT_EQ
(
m
.
size
.
dims
(),
2
);
if
(
shape
.
Rank
()
>
0
)
{
ASSERT_EQ
(
m
.
rows
,
shape
[
0
]);
}
if
(
shape
.
Rank
()
>
1
)
{
ASSERT_EQ
(
m
.
cols
,
shape
[
1
]);
}
if
(
shape
.
Rank
()
>
0
)
{
ASSERT_EQ
(
m
.
rows
,
shape
[
0
]);
}
if
(
shape
.
Rank
()
>
1
)
{
ASSERT_EQ
(
m
.
cols
,
shape
[
1
]);
}
}
else
{
for
(
dsize_t
i
=
0
;
i
<
shape
.
Rank
();
i
++
)
ASSERT_EQ
(
m
.
size
[
static_cast
<
int
>
(
i
)],
shape
[
i
]);
for
(
dsize_t
i
=
0
;
i
<
shape
.
Rank
();
i
++
)
ASSERT_EQ
(
m
.
size
[
static_cast
<
int
>
(
i
)],
shape
[
i
]);
ASSERT_EQ
(
m
.
dims
,
shape
.
Rank
());
ASSERT_EQ
(
m
.
size
.
dims
(),
shape
.
Rank
());
ASSERT_EQ
(
m
.
rows
,
-
1
);
...
...
@@ -394,3 +394,16 @@ TEST_F(MindDataTestTensorDE, TensorIterator) {
}
ASSERT_TRUE
(
ctr
==
6
);
}
TEST_F
(
MindDataTestTensorDE
,
TensorSlice
)
{
std
::
shared_ptr
<
Tensor
>
t
;
Tensor
::
CreateTensor
(
&
t
,
std
::
vector
<
dsize_t
>
{
0
,
1
,
2
,
3
,
4
});
std
::
shared_ptr
<
Tensor
>
t2
;
auto
x
=
std
::
vector
<
dsize_t
>
{
0
,
3
,
4
};
std
::
shared_ptr
<
Tensor
>
expected
;
Tensor
::
CreateTensor
(
&
expected
,
x
);
t
->
Slice
(
&
t2
,
x
);
ASSERT_EQ
(
*
t2
,
*
expected
);
t
->
Slice
(
&
t2
,
std
::
vector
<
dsize_t
>
{
0
,
1
,
2
,
3
,
4
});
ASSERT_EQ
(
*
t2
,
*
t
);
}
tests/ut/python/dataset/test_slice_op.py
0 → 100644
浏览文件 @
e2012a1d
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing TypeCast op in DE
"""
import
numpy
as
np
import
pytest
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.c_transforms
as
ops
def
slice_compare
(
array
,
indexing
):
data
=
ds
.
NumpySlicesDataset
([
array
])
array
=
np
.
array
(
array
)
data
=
data
.
map
(
operations
=
ops
.
Slice
(
indexing
))
for
d
in
data
:
if
indexing
is
None
:
array
=
array
[:]
else
:
array
=
array
[
indexing
]
np
.
testing
.
assert_array_equal
(
array
,
d
[
0
])
def
test_slice_all
():
slice_compare
([
1
,
2
,
3
,
4
,
5
],
None
)
slice_compare
([
1
,
2
,
3
,
4
,
5
],
...)
def
test_slice_single_index
():
slice_compare
([
1
,
2
,
3
,
4
,
5
],
0
)
slice_compare
([
1
,
2
,
3
,
4
,
5
],
4
)
slice_compare
([
1
,
2
,
3
,
4
,
5
],
2
)
slice_compare
([
1
,
2
,
3
,
4
,
5
],
-
1
)
slice_compare
([
1
,
2
,
3
,
4
,
5
],
-
5
)
slice_compare
([
1
,
2
,
3
,
4
,
5
],
-
3
)
def
test_slice_list_index
():
slice_compare
([
1
,
2
,
3
,
4
,
5
],
[
0
,
1
,
4
])
slice_compare
([
1
,
2
,
3
,
4
,
5
],
[
4
,
1
,
0
])
slice_compare
([
1
,
2
,
3
,
4
,
5
],
[
-
1
,
1
,
0
])
slice_compare
([
1
,
2
,
3
,
4
,
5
],
[
-
1
,
-
4
,
-
2
])
slice_compare
([
1
,
2
,
3
,
4
,
5
],
[
3
,
3
,
3
])
slice_compare
([
1
,
2
,
3
,
4
,
5
],
[
1
,
1
,
1
,
1
,
1
])
def
test_slice_slice_obj_2s
():
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
0
,
2
))
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
2
,
4
))
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
4
,
10
))
def
test_slice_slice_obj_1s
():
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
1
))
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
4
))
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
10
))
def
test_slice_slice_obj_3s
():
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
0
,
2
,
1
))
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
0
,
4
,
1
))
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
0
,
10
,
1
))
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
0
,
5
,
2
))
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
0
,
2
,
2
))
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
0
,
1
,
2
))
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
4
,
5
,
1
))
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
2
,
5
,
3
))
def
test_slice_slice_obj_3s_double
():
slice_compare
([
1.
,
2.
,
3.
,
4.
,
5.
],
slice
(
0
,
2
,
1
))
slice_compare
([
1.
,
2.
,
3.
,
4.
,
5.
],
slice
(
0
,
4
,
1
))
slice_compare
([
1.
,
2.
,
3.
,
4.
,
5.
],
slice
(
0
,
10
,
1
))
slice_compare
([
1.
,
2.
,
3.
,
4.
,
5.
],
slice
(
0
,
5
,
2
))
slice_compare
([
1.
,
2.
,
3.
,
4.
,
5.
],
slice
(
0
,
2
,
2
))
slice_compare
([
1.
,
2.
,
3.
,
4.
,
5.
],
slice
(
0
,
1
,
2
))
slice_compare
([
1.
,
2.
,
3.
,
4.
,
5.
],
slice
(
4
,
5
,
1
))
slice_compare
([
1.
,
2.
,
3.
,
4.
,
5.
],
slice
(
2
,
5
,
3
))
def
test_slice_slice_obj_neg
():
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
-
1
,
-
5
,
-
1
))
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
-
1
))
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
-
2
))
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
-
1
,
-
5
,
-
2
))
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
-
5
,
-
1
,
2
))
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
-
5
,
-
1
))
def
test_slice_exceptions
():
with
pytest
.
raises
(
RuntimeError
)
as
info
:
slice_compare
([
1
,
2
,
3
,
4
,
5
],
5
)
assert
"Index 5 is out of bounds [0,5)"
in
str
(
info
.
value
)
with
pytest
.
raises
(
RuntimeError
)
as
info
:
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
0
))
assert
"Indices are empty, generated tensor would be empty."
in
str
(
info
.
value
)
with
pytest
.
raises
(
RuntimeError
)
as
info
:
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
5
,
10
,
1
))
assert
"Indices are empty, generated tensor would be empty."
in
str
(
info
.
value
)
with
pytest
.
raises
(
RuntimeError
)
as
info
:
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
-
1
,
-
5
,
1
))
assert
"Indices are empty, generated tensor would be empty."
in
str
(
info
.
value
)
def
test_slice_all_str
():
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
None
)
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
...)
def
test_slice_single_index_str
():
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
0
)
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
4
)
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
2
)
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
-
1
)
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
-
5
)
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
-
3
)
def
test_slice_list_index_str
():
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
[
0
,
1
,
4
])
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
[
4
,
1
,
0
])
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
[
-
1
,
1
,
0
])
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
[
-
1
,
-
4
,
-
2
])
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
[
3
,
3
,
3
])
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
[
1
,
1
,
1
,
1
,
1
])
def
test_slice_slice_obj_2s_str
():
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
0
,
2
))
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
2
,
4
))
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
4
,
10
))
def
test_slice_slice_obj_1s_str
():
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
1
))
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
4
))
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
10
))
def
test_slice_slice_obj_3s_str
():
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
0
,
2
,
1
))
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
0
,
4
,
1
))
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
0
,
10
,
1
))
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
0
,
5
,
2
))
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
0
,
2
,
2
))
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
0
,
1
,
2
))
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
4
,
5
,
1
))
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
2
,
5
,
3
))
def
test_slice_slice_obj_neg_str
():
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
-
1
,
-
5
,
-
1
))
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
-
1
))
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
-
2
))
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
-
1
,
-
5
,
-
2
))
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
-
5
,
-
1
,
2
))
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
-
5
,
-
1
))
def
test_slice_exceptions_str
():
with
pytest
.
raises
(
RuntimeError
)
as
info
:
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
5
)
assert
"Index 5 is out of bounds [0,5)"
in
str
(
info
.
value
)
with
pytest
.
raises
(
RuntimeError
)
as
info
:
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
0
))
assert
"Indices are empty, generated tensor would be empty."
in
str
(
info
.
value
)
with
pytest
.
raises
(
RuntimeError
)
as
info
:
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
5
,
10
,
1
))
assert
"Indices are empty, generated tensor would be empty."
in
str
(
info
.
value
)
with
pytest
.
raises
(
RuntimeError
)
as
info
:
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
-
1
,
-
5
,
1
))
assert
"Indices are empty, generated tensor would be empty."
in
str
(
info
.
value
)
if
__name__
==
"__main__"
:
test_slice_all
()
test_slice_single_index
()
test_slice_list_index
()
test_slice_slice_obj_3s
()
test_slice_slice_obj_2s
()
test_slice_slice_obj_1s
()
test_slice_slice_obj_neg
()
test_slice_exceptions
()
test_slice_slice_obj_3s_double
()
test_slice_all_str
()
test_slice_single_index_str
()
test_slice_list_index_str
()
test_slice_slice_obj_3s_str
()
test_slice_slice_obj_2s_str
()
test_slice_slice_obj_1s_str
()
test_slice_slice_obj_neg_str
()
test_slice_exceptions_str
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录