Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
b37db1ed
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b37db1ed
编写于
4月 27, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 27, 2020
浏览文件
操作
浏览文件
下载
差异文件
!603 [MD] update pk sampler in minddataset
Merge pull request !603 from liyong126/update_pk_sampler
上级
56003327
bfba630a
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
35 addition
and
6 deletion
+35
-6
mindspore/ccsrc/dataset/api/python_bindings.cc
mindspore/ccsrc/dataset/api/python_bindings.cc
+3
-3
mindspore/ccsrc/mindrecord/io/shard_reader.cc
mindspore/ccsrc/mindrecord/io/shard_reader.cc
+9
-0
mindspore/ccsrc/mindrecord/meta/shard_category.cc
mindspore/ccsrc/mindrecord/meta/shard_category.cc
+1
-1
mindspore/dataset/engine/samplers.py
mindspore/dataset/engine/samplers.py
+7
-2
tests/ut/python/dataset/test_minddataset_exception.py
tests/ut/python/dataset/test_minddataset_exception.py
+15
-0
未找到文件。
mindspore/ccsrc/dataset/api/python_bindings.cc
浏览文件 @
b37db1ed
...
...
@@ -435,12 +435,12 @@ void bindSamplerOps(py::module *m) {
.
def
(
py
::
init
<
std
::
vector
<
int64_t
>
,
uint32_t
>
(),
py
::
arg
(
"indices"
),
py
::
arg
(
"seed"
)
=
GetSeed
());
(
void
)
py
::
class_
<
mindrecord
::
ShardPkSample
,
mindrecord
::
ShardOperator
,
std
::
shared_ptr
<
mindrecord
::
ShardPkSample
>>
(
*
m
,
"MindrecordPkSampler"
)
.
def
(
py
::
init
([](
int64_t
kVal
,
bool
shuffle
)
{
.
def
(
py
::
init
([](
int64_t
kVal
,
std
::
string
kColumn
,
bool
shuffle
)
{
if
(
shuffle
==
true
)
{
return
std
::
make_shared
<
mindrecord
::
ShardPkSample
>
(
"label"
,
kVal
,
std
::
numeric_limits
<
int64_t
>::
max
(),
return
std
::
make_shared
<
mindrecord
::
ShardPkSample
>
(
kColumn
,
kVal
,
std
::
numeric_limits
<
int64_t
>::
max
(),
GetSeed
());
}
else
{
return
std
::
make_shared
<
mindrecord
::
ShardPkSample
>
(
"label"
,
kVal
);
return
std
::
make_shared
<
mindrecord
::
ShardPkSample
>
(
kColumn
,
kVal
);
}
}));
...
...
mindspore/ccsrc/mindrecord/io/shard_reader.cc
浏览文件 @
b37db1ed
...
...
@@ -316,6 +316,10 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql,
}
MSRStatus
ShardReader
::
GetAllClasses
(
const
std
::
string
&
category_field
,
std
::
set
<
std
::
string
>
&
categories
)
{
if
(
column_schema_id_
.
find
(
category_field
)
==
column_schema_id_
.
end
())
{
MS_LOG
(
ERROR
)
<<
"Field "
<<
category_field
<<
" does not exist."
;
return
FAILED
;
}
auto
ret
=
ShardIndexGenerator
::
GenerateFieldName
(
std
::
make_pair
(
column_schema_id_
[
category_field
],
category_field
));
if
(
SUCCESS
!=
ret
.
first
)
{
return
FAILED
;
...
...
@@ -719,6 +723,11 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
for
(
auto
&
field
:
index_fields
)
{
map_schema_id_fields
[
field
.
second
]
=
field
.
first
;
}
if
(
map_schema_id_fields
.
find
(
category_field
)
==
map_schema_id_fields
.
end
())
{
MS_LOG
(
ERROR
)
<<
"Field "
<<
category_field
<<
" does not exist."
;
return
-
1
;
}
auto
ret
=
ShardIndexGenerator
::
GenerateFieldName
(
std
::
make_pair
(
map_schema_id_fields
[
category_field
],
category_field
));
if
(
SUCCESS
!=
ret
.
first
)
{
...
...
mindspore/ccsrc/mindrecord/meta/shard_category.cc
浏览文件 @
b37db1ed
...
...
@@ -38,7 +38,7 @@ MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; }
int64_t
ShardCategory
::
GetNumSamples
(
int64_t
dataset_size
,
int64_t
num_classes
)
{
if
(
dataset_size
==
0
)
return
dataset_size
;
if
(
dataset_size
>
0
&&
num_categories_
>
0
&&
num_elements_
>
0
)
{
if
(
dataset_size
>
0
&&
num_c
lasses
>
0
&&
num_c
ategories_
>
0
&&
num_elements_
>
0
)
{
return
std
::
min
(
num_categories_
,
num_classes
)
*
num_elements_
;
}
return
-
1
;
...
...
mindspore/dataset/engine/samplers.py
浏览文件 @
b37db1ed
...
...
@@ -152,6 +152,7 @@ class PKSampler(BuiltinSampler):
num_val (int): Number of elements to sample for each class.
num_class (int, optional): Number of classes to sample (default=None, all classes).
shuffle (bool, optional): If true, the class IDs are shuffled (default=False).
class_column (str, optional): Name of column to classify dataset(default='label'), for MindDataset.
Examples:
>>> import mindspore.dataset as ds
...
...
@@ -168,7 +169,7 @@ class PKSampler(BuiltinSampler):
ValueError: If shuffle is not boolean.
"""
def
__init__
(
self
,
num_val
,
num_class
=
None
,
shuffle
=
False
):
def
__init__
(
self
,
num_val
,
num_class
=
None
,
shuffle
=
False
,
class_column
=
'label'
):
if
num_val
<=
0
:
raise
ValueError
(
"num_val should be a positive integer value, but got num_val={}"
.
format
(
num_val
))
...
...
@@ -180,12 +181,16 @@ class PKSampler(BuiltinSampler):
self
.
num_val
=
num_val
self
.
shuffle
=
shuffle
self
.
class_column
=
class_column
# work for minddataset
def
create
(
self
):
return
cde
.
PKSampler
(
self
.
num_val
,
self
.
shuffle
)
def
_create_for_minddataset
(
self
):
return
cde
.
MindrecordPkSampler
(
self
.
num_val
,
self
.
shuffle
)
if
not
self
.
class_column
or
not
isinstance
(
self
.
class_column
,
str
):
raise
ValueError
(
"class_column should be a not empty string value,
\
but got class_column={}"
.
format
(
class_column
))
return
cde
.
MindrecordPkSampler
(
self
.
num_val
,
self
.
class_column
,
self
.
shuffle
)
class
RandomSampler
(
BuiltinSampler
):
"""
...
...
tests/ut/python/dataset/test_minddataset_exception.py
浏览文件 @
b37db1ed
...
...
@@ -82,3 +82,18 @@ def test_minddataset_lack_db():
num_iter
+=
1
assert
num_iter
==
0
os
.
remove
(
CV_FILE_NAME
)
def
test_cv_minddataset_pk_sample_error_class_column
():
create_cv_mindrecord
(
1
)
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
sampler
=
ds
.
PKSampler
(
5
,
None
,
True
,
'no_exsit_column'
)
with
pytest
.
raises
(
Exception
,
match
=
"MindRecordOp launch failed"
):
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
,
columns_list
,
num_readers
,
sampler
=
sampler
)
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
num_iter
+=
1
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录