Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
beb436f4
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
beb436f4
编写于
6月 22, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 22, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2398 Bug in Slice when multiple rows are used
Merge pull request !2398 from h.farahat/slice_bug
上级
a9c309da
68030e6a
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
19 addition
and
2 deletion
+19
-2
mindspore/ccsrc/dataset/kernels/data/slice_op.cc
mindspore/ccsrc/dataset/kernels/data/slice_op.cc
+2
-2
tests/ut/python/dataset/test_slice_op.py
tests/ut/python/dataset/test_slice_op.py
+17
-0
未找到文件。
mindspore/ccsrc/dataset/kernels/data/slice_op.cc
浏览文件 @
beb436f4
...
@@ -33,8 +33,8 @@ Status SliceOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
...
@@ -33,8 +33,8 @@ Status SliceOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
// if slice object was provided, indices should be empty. Generate indices from the slice object.
// if slice object was provided, indices should be empty. Generate indices from the slice object.
if
(
slice_
.
valid
()
&&
indices_
.
empty
())
{
if
(
slice_
.
valid
()
&&
indices_
.
empty
())
{
dsize_t
len
=
input
->
shape
()[
0
];
dsize_t
len
=
input
->
shape
()[
0
];
indices_
=
slice_
.
Indices
(
len
);
std
::
vector
<
dsize_t
>
indices
=
slice_
.
Indices
(
len
);
return
input
->
Slice
(
output
,
indices
_
);
return
input
->
Slice
(
output
,
indices
);
}
}
// if indices are not empty, slices should be invalid, use indices_ to slice
// if indices are not empty, slices should be invalid, use indices_ to slice
...
...
tests/ut/python/dataset/test_slice_op.py
浏览文件 @
beb436f4
...
@@ -80,6 +80,22 @@ def test_slice_slice_obj_3s():
...
@@ -80,6 +80,22 @@ def test_slice_slice_obj_3s():
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
2
,
5
,
3
))
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
2
,
5
,
3
))
def
test_slice_multiple_rows
():
dataset
=
[[
1
,
2
],
[
3
,
4
,
5
],
[
1
],
[
1
,
2
,
3
,
4
,
5
,
6
,
7
]]
def
gen
():
for
row
in
dataset
:
yield
(
np
.
array
(
row
),)
data
=
ds
.
GeneratorDataset
(
gen
,
column_names
=
[
"col"
])
indexing
=
slice
(
0
,
4
)
data
=
data
.
map
(
operations
=
ops
.
Slice
(
indexing
))
for
i
,
d
in
enumerate
(
data
):
array
=
np
.
array
(
dataset
[
i
])
array
=
array
[
indexing
]
np
.
testing
.
assert_array_equal
(
array
,
d
[
0
])
def
test_slice_slice_obj_3s_double
():
def
test_slice_slice_obj_3s_double
():
slice_compare
([
1.
,
2.
,
3.
,
4.
,
5.
],
slice
(
0
,
2
,
1
))
slice_compare
([
1.
,
2.
,
3.
,
4.
,
5.
],
slice
(
0
,
2
,
1
))
slice_compare
([
1.
,
2.
,
3.
,
4.
,
5.
],
slice
(
0
,
4
,
1
))
slice_compare
([
1.
,
2.
,
3.
,
4.
,
5.
],
slice
(
0
,
4
,
1
))
...
@@ -217,3 +233,4 @@ if __name__ == "__main__":
...
@@ -217,3 +233,4 @@ if __name__ == "__main__":
test_slice_slice_obj_1s_str
()
test_slice_slice_obj_1s_str
()
test_slice_slice_obj_neg_str
()
test_slice_slice_obj_neg_str
()
test_slice_exceptions_str
()
test_slice_exceptions_str
()
test_slice_multiple_rows
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录