Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
b520ca90
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b520ca90
编写于
5月 07, 2020
作者:
L
liyong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix pk sampler in mindrecord
上级
5a03bd80
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
37 addition
and
5 deletion
+37
-5
mindspore/ccsrc/mindrecord/io/shard_reader.cc
mindspore/ccsrc/mindrecord/io/shard_reader.cc
+7
-3
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+2
-2
tests/ut/python/dataset/test_minddataset_exception.py
tests/ut/python/dataset/test_minddataset_exception.py
+14
-0
tests/ut/python/dataset/test_minddataset_sampler.py
tests/ut/python/dataset/test_minddataset_sampler.py
+14
-0
未找到文件。
mindspore/ccsrc/mindrecord/io/shard_reader.cc
浏览文件 @
b520ca90
...
...
@@ -316,11 +316,15 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql,
}
MSRStatus
ShardReader
::
GetAllClasses
(
const
std
::
string
&
category_field
,
std
::
set
<
std
::
string
>
&
categories
)
{
if
(
column_schema_id_
.
find
(
category_field
)
==
column_schema_id_
.
end
())
{
MS_LOG
(
ERROR
)
<<
"Field "
<<
category_field
<<
" does not exist."
;
std
::
map
<
std
::
string
,
uint64_t
>
index_columns
;
for
(
auto
&
field
:
get_shard_header
()
->
get_fields
())
{
index_columns
[
field
.
second
]
=
field
.
first
;
}
if
(
index_columns
.
find
(
category_field
)
==
index_columns
.
end
())
{
MS_LOG
(
ERROR
)
<<
"Index field "
<<
category_field
<<
" does not exist."
;
return
FAILED
;
}
auto
ret
=
ShardIndexGenerator
::
GenerateFieldName
(
std
::
make_pair
(
column_schema_id_
[
category_field
],
category_field
));
auto
ret
=
ShardIndexGenerator
::
GenerateFieldName
(
std
::
make_pair
(
index_columns
[
category_field
],
category_field
));
if
(
SUCCESS
!=
ret
.
first
)
{
return
FAILED
;
}
...
...
mindspore/dataset/engine/datasets.py
浏览文件 @
b520ca90
...
...
@@ -2224,8 +2224,8 @@ class MindDataset(SourceDataset):
if
block_reader
is
True
and
sampler
is
not
None
:
raise
ValueError
(
"block reader not allowed true when use sampler"
)
if
shuffle
is
Tru
e
and
sampler
is
not
None
:
raise
ValueError
(
"shuffle not allowed
true
when use sampler"
)
if
shuffle
is
not
Non
e
and
sampler
is
not
None
:
raise
ValueError
(
"shuffle not allowed when use sampler"
)
if
block_reader
is
False
and
sampler
is
None
:
self
.
global_shuffle
=
not
bool
(
shuffle
is
False
)
...
...
tests/ut/python/dataset/test_minddataset_exception.py
浏览文件 @
b520ca90
...
...
@@ -97,3 +97,17 @@ def test_cv_minddataset_pk_sample_error_class_column():
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
def
test_cv_minddataset_pk_sample_exclusive_shuffle
():
create_cv_mindrecord
(
1
)
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
sampler
=
ds
.
PKSampler
(
2
)
with
pytest
.
raises
(
Exception
,
match
=
"shuffle not allowed when use sampler"
):
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
,
columns_list
,
num_readers
,
sampler
=
sampler
,
shuffle
=
False
)
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
num_iter
+=
1
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
tests/ut/python/dataset/test_minddataset_sampler.py
浏览文件 @
b520ca90
...
...
@@ -60,7 +60,21 @@ def add_and_remove_cv_file():
os
.
remove
(
"{}"
.
format
(
x
))
os
.
remove
(
"{}.db"
.
format
(
x
))
def
test_cv_minddataset_pk_sample_no_column
(
add_and_remove_cv_file
):
"""tutorial for cv minderdataset."""
num_readers
=
4
sampler
=
ds
.
PKSampler
(
2
)
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
+
"0"
,
None
,
num_readers
,
sampler
=
sampler
)
assert
data_set
.
get_dataset_size
()
==
6
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
"-------------- cv reader basic: {} ------------------------"
.
format
(
num_iter
))
logger
.
info
(
"-------------- item[file_name]:
\
{}------------------------"
.
format
(
""
.
join
([
chr
(
x
)
for
x
in
item
[
"file_name"
]])))
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
num_iter
+=
1
def
test_cv_minddataset_pk_sample_basic
(
add_and_remove_cv_file
):
"""tutorial for cv minderdataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录