Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
7341421d
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7341421d
编写于
8月 10, 2020
作者:
L
liyong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix num samples in pk sampler
上级
4276050f
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
133 addition
and
22 deletion
+133
-22
mindspore/ccsrc/minddata/dataset/api/python/bindings/mindrecord/include/bindings.cc
...ataset/api/python/bindings/mindrecord/include/bindings.cc
+3
-3
mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h
...spore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h
+7
-3
mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
+1
-0
mindspore/ccsrc/minddata/mindrecord/include/shard_task.h
mindspore/ccsrc/minddata/mindrecord/include/shard_task.h
+2
-1
mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
+15
-1
mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc
mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc
+9
-6
mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc
mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc
+8
-1
mindspore/dataset/engine/samplers.py
mindspore/dataset/engine/samplers.py
+2
-1
mindspore/mindrecord/tools/tfrecord_to_mr.py
mindspore/mindrecord/tools/tfrecord_to_mr.py
+1
-1
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
+3
-3
tests/ut/python/dataset/test_minddataset_sampler.py
tests/ut/python/dataset/test_minddataset_sampler.py
+82
-2
未找到文件。
mindspore/ccsrc/minddata/dataset/api/python/bindings/mindrecord/include/bindings.cc
浏览文件 @
7341421d
...
@@ -48,12 +48,12 @@ PYBIND_REGISTER(
...
@@ -48,12 +48,12 @@ PYBIND_REGISTER(
ShardPkSample
,
1
,
([](
const
py
::
module
*
m
)
{
ShardPkSample
,
1
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
class_
<
mindrecord
::
ShardPkSample
,
mindrecord
::
ShardOperator
,
std
::
shared_ptr
<
mindrecord
::
ShardPkSample
>>
(
(
void
)
py
::
class_
<
mindrecord
::
ShardPkSample
,
mindrecord
::
ShardOperator
,
std
::
shared_ptr
<
mindrecord
::
ShardPkSample
>>
(
*
m
,
"MindrecordPkSampler"
)
*
m
,
"MindrecordPkSampler"
)
.
def
(
py
::
init
([](
int64_t
kVal
,
std
::
string
kColumn
,
bool
shuffle
)
{
.
def
(
py
::
init
([](
int64_t
kVal
,
std
::
string
kColumn
,
bool
shuffle
,
int64_t
num_samples
)
{
if
(
shuffle
==
true
)
{
if
(
shuffle
==
true
)
{
return
std
::
make_shared
<
mindrecord
::
ShardPkSample
>
(
kColumn
,
kVal
,
std
::
numeric_limits
<
int64_t
>::
max
(),
return
std
::
make_shared
<
mindrecord
::
ShardPkSample
>
(
kColumn
,
kVal
,
std
::
numeric_limits
<
int64_t
>::
max
(),
GetSeed
());
GetSeed
()
,
num_samples
);
}
else
{
}
else
{
return
std
::
make_shared
<
mindrecord
::
ShardPkSample
>
(
kColumn
,
kVal
);
return
std
::
make_shared
<
mindrecord
::
ShardPkSample
>
(
kColumn
,
kVal
,
num_samples
);
}
}
}));
}));
}));
}));
...
...
mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h
浏览文件 @
7341421d
...
@@ -29,19 +29,23 @@ namespace mindspore {
...
@@ -29,19 +29,23 @@ namespace mindspore {
namespace
mindrecord
{
namespace
mindrecord
{
class
ShardPkSample
:
public
ShardCategory
{
class
ShardPkSample
:
public
ShardCategory
{
public:
public:
ShardPkSample
(
const
std
::
string
&
category_field
,
int64_t
num_elements
);
ShardPkSample
(
const
std
::
string
&
category_field
,
int64_t
num_elements
,
int64_t
num_samples
);
ShardPkSample
(
const
std
::
string
&
category_field
,
int64_t
num_elements
,
int64_t
num_categories
);
ShardPkSample
(
const
std
::
string
&
category_field
,
int64_t
num_elements
,
int64_t
num_categories
,
int64_t
num_samples
);
ShardPkSample
(
const
std
::
string
&
category_field
,
int64_t
num_elements
,
int64_t
num_categories
,
uint32_t
seed
);
ShardPkSample
(
const
std
::
string
&
category_field
,
int64_t
num_elements
,
int64_t
num_categories
,
uint32_t
seed
,
int64_t
num_samples
);
~
ShardPkSample
()
override
{};
~
ShardPkSample
()
override
{};
MSRStatus
SufExecute
(
ShardTask
&
tasks
)
override
;
MSRStatus
SufExecute
(
ShardTask
&
tasks
)
override
;
int64_t
GetNumSamples
()
const
{
return
num_samples_
;
}
private:
private:
bool
shuffle_
;
bool
shuffle_
;
std
::
shared_ptr
<
ShardShuffle
>
shuffle_op_
;
std
::
shared_ptr
<
ShardShuffle
>
shuffle_op_
;
int64_t
num_samples_
;
};
};
}
// namespace mindrecord
}
// namespace mindrecord
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
浏览文件 @
7341421d
...
@@ -49,6 +49,7 @@
...
@@ -49,6 +49,7 @@
#include "minddata/mindrecord/include/shard_error.h"
#include "minddata/mindrecord/include/shard_error.h"
#include "minddata/mindrecord/include/shard_index_generator.h"
#include "minddata/mindrecord/include/shard_index_generator.h"
#include "minddata/mindrecord/include/shard_operator.h"
#include "minddata/mindrecord/include/shard_operator.h"
#include "minddata/mindrecord/include/shard_pk_sample.h"
#include "minddata/mindrecord/include/shard_reader.h"
#include "minddata/mindrecord/include/shard_reader.h"
#include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_shuffle.h"
#include "minddata/mindrecord/include/shard_shuffle.h"
...
...
mindspore/ccsrc/minddata/mindrecord/include/shard_task.h
浏览文件 @
7341421d
...
@@ -53,7 +53,8 @@ class ShardTask {
...
@@ -53,7 +53,8 @@ class ShardTask {
std
::
tuple
<
TaskType
,
std
::
tuple
<
int
,
int
>
,
std
::
vector
<
uint64_t
>
,
json
>
&
GetRandomTask
();
std
::
tuple
<
TaskType
,
std
::
tuple
<
int
,
int
>
,
std
::
vector
<
uint64_t
>
,
json
>
&
GetRandomTask
();
static
ShardTask
Combine
(
std
::
vector
<
ShardTask
>
&
category_tasks
,
bool
replacement
,
int64_t
num_elements
);
static
ShardTask
Combine
(
std
::
vector
<
ShardTask
>
&
category_tasks
,
bool
replacement
,
int64_t
num_elements
,
int64_t
num_samples
);
uint32_t
categories
;
uint32_t
categories
;
...
...
mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
浏览文件 @
7341421d
...
@@ -827,6 +827,12 @@ MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths
...
@@ -827,6 +827,12 @@ MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths
std
::
string
category_field
=
category_op
->
GetCategoryField
();
std
::
string
category_field
=
category_op
->
GetCategoryField
();
auto
num_classes
=
GetNumClasses
(
category_field
);
auto
num_classes
=
GetNumClasses
(
category_field
);
num_samples
=
category_op
->
GetNumSamples
(
num_samples
,
num_classes
);
num_samples
=
category_op
->
GetNumSamples
(
num_samples
,
num_classes
);
if
(
std
::
dynamic_pointer_cast
<
ShardPkSample
>
(
op
))
{
auto
tmp
=
std
::
dynamic_pointer_cast
<
ShardPkSample
>
(
op
)
->
GetNumSamples
();
if
(
tmp
!=
0
)
{
num_samples
=
std
::
min
(
num_samples
,
tmp
);
}
}
}
else
if
(
std
::
dynamic_pointer_cast
<
ShardSample
>
(
op
))
{
}
else
if
(
std
::
dynamic_pointer_cast
<
ShardSample
>
(
op
))
{
if
(
std
::
dynamic_pointer_cast
<
ShardDistributedSample
>
(
op
))
{
if
(
std
::
dynamic_pointer_cast
<
ShardDistributedSample
>
(
op
))
{
auto
sampler_op
=
std
::
dynamic_pointer_cast
<
ShardDistributedSample
>
(
op
);
auto
sampler_op
=
std
::
dynamic_pointer_cast
<
ShardDistributedSample
>
(
op
);
...
@@ -958,6 +964,14 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
...
@@ -958,6 +964,14 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
auto
category_op
=
std
::
dynamic_pointer_cast
<
ShardCategory
>
(
op
);
auto
category_op
=
std
::
dynamic_pointer_cast
<
ShardCategory
>
(
op
);
auto
categories
=
category_op
->
GetCategories
();
auto
categories
=
category_op
->
GetCategories
();
int64_t
num_elements
=
category_op
->
GetNumElements
();
int64_t
num_elements
=
category_op
->
GetNumElements
();
int64_t
num_samples
=
0
;
if
(
std
::
dynamic_pointer_cast
<
ShardPkSample
>
(
op
))
{
num_samples
=
std
::
dynamic_pointer_cast
<
ShardPkSample
>
(
op
)
->
GetNumSamples
();
if
(
num_samples
<
0
)
{
MS_LOG
(
ERROR
)
<<
"Parameter num_samples is not positive or zero"
;
return
FAILED
;
}
}
if
(
num_elements
<=
0
)
{
if
(
num_elements
<=
0
)
{
MS_LOG
(
ERROR
)
<<
"Parameter num_element is not positive"
;
MS_LOG
(
ERROR
)
<<
"Parameter num_element is not positive"
;
return
FAILED
;
return
FAILED
;
...
@@ -1006,7 +1020,7 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
...
@@ -1006,7 +1020,7 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
}
}
MS_LOG
(
INFO
)
<<
"Category #"
<<
categoryNo
<<
" has "
<<
categoryTasks
[
categoryNo
].
Size
()
<<
" tasks"
;
MS_LOG
(
INFO
)
<<
"Category #"
<<
categoryNo
<<
" has "
<<
categoryTasks
[
categoryNo
].
Size
()
<<
" tasks"
;
}
}
tasks_
=
ShardTask
::
Combine
(
categoryTasks
,
category_op
->
GetReplacement
(),
num_elements
);
tasks_
=
ShardTask
::
Combine
(
categoryTasks
,
category_op
->
GetReplacement
(),
num_elements
,
num_samples
);
if
(
SUCCESS
!=
(
*
category_op
)(
tasks_
))
{
if
(
SUCCESS
!=
(
*
category_op
)(
tasks_
))
{
return
FAILED
;
return
FAILED
;
}
}
...
...
mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc
浏览文件 @
7341421d
...
@@ -22,15 +22,18 @@ using mindspore::MsLogLevel::ERROR;
...
@@ -22,15 +22,18 @@ using mindspore::MsLogLevel::ERROR;
namespace
mindspore
{
namespace
mindspore
{
namespace
mindrecord
{
namespace
mindrecord
{
ShardPkSample
::
ShardPkSample
(
const
std
::
string
&
category_field
,
int64_t
num_elements
)
ShardPkSample
::
ShardPkSample
(
const
std
::
string
&
category_field
,
int64_t
num_elements
,
int64_t
num_samples
)
:
ShardCategory
(
category_field
,
num_elements
,
std
::
numeric_limits
<
int64_t
>::
max
(),
true
),
shuffle_
(
false
)
{}
:
ShardCategory
(
category_field
,
num_elements
,
std
::
numeric_limits
<
int64_t
>::
max
(),
true
),
shuffle_
(
false
),
num_samples_
(
num_samples
)
{}
ShardPkSample
::
ShardPkSample
(
const
std
::
string
&
category_field
,
int64_t
num_elements
,
int64_t
num_categories
)
ShardPkSample
::
ShardPkSample
(
const
std
::
string
&
category_field
,
int64_t
num_elements
,
int64_t
num_categories
,
:
ShardCategory
(
category_field
,
num_elements
,
num_categories
,
true
),
shuffle_
(
false
)
{}
int64_t
num_samples
)
:
ShardCategory
(
category_field
,
num_elements
,
num_categories
,
true
),
shuffle_
(
false
),
num_samples_
(
num_samples
)
{}
ShardPkSample
::
ShardPkSample
(
const
std
::
string
&
category_field
,
int64_t
num_elements
,
int64_t
num_categories
,
ShardPkSample
::
ShardPkSample
(
const
std
::
string
&
category_field
,
int64_t
num_elements
,
int64_t
num_categories
,
uint32_t
seed
)
uint32_t
seed
,
int64_t
num_samples
)
:
ShardCategory
(
category_field
,
num_elements
,
num_categories
,
true
),
shuffle_
(
true
)
{
:
ShardCategory
(
category_field
,
num_elements
,
num_categories
,
true
),
shuffle_
(
true
)
,
num_samples_
(
num_samples
)
{
shuffle_op_
=
std
::
make_shared
<
ShardShuffle
>
(
seed
,
kShuffleSample
);
// do shuffle and replacement
shuffle_op_
=
std
::
make_shared
<
ShardShuffle
>
(
seed
,
kShuffleSample
);
// do shuffle and replacement
}
}
...
...
mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc
浏览文件 @
7341421d
...
@@ -86,7 +86,8 @@ std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTa
...
@@ -86,7 +86,8 @@ std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTa
return
task_list_
[
dis
(
gen
)];
return
task_list_
[
dis
(
gen
)];
}
}
ShardTask
ShardTask
::
Combine
(
std
::
vector
<
ShardTask
>
&
category_tasks
,
bool
replacement
,
int64_t
num_elements
)
{
ShardTask
ShardTask
::
Combine
(
std
::
vector
<
ShardTask
>
&
category_tasks
,
bool
replacement
,
int64_t
num_elements
,
int64_t
num_samples
)
{
ShardTask
res
;
ShardTask
res
;
if
(
category_tasks
.
empty
())
return
res
;
if
(
category_tasks
.
empty
())
return
res
;
auto
total_categories
=
category_tasks
.
size
();
auto
total_categories
=
category_tasks
.
size
();
...
@@ -96,9 +97,12 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
...
@@ -96,9 +97,12 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
for
(
uint32_t
i
=
1
;
i
<
total_categories
;
i
++
)
{
for
(
uint32_t
i
=
1
;
i
<
total_categories
;
i
++
)
{
minTasks
=
std
::
min
(
minTasks
,
category_tasks
[
i
].
Size
());
minTasks
=
std
::
min
(
minTasks
,
category_tasks
[
i
].
Size
());
}
}
int64_t
count
=
0
;
for
(
uint32_t
task_no
=
0
;
task_no
<
minTasks
;
task_no
++
)
{
for
(
uint32_t
task_no
=
0
;
task_no
<
minTasks
;
task_no
++
)
{
for
(
uint32_t
i
=
0
;
i
<
total_categories
;
i
++
)
{
for
(
uint32_t
i
=
0
;
i
<
total_categories
;
i
++
)
{
if
(
num_samples
!=
0
&&
count
==
num_samples
)
break
;
res
.
InsertTask
(
std
::
move
(
category_tasks
[
i
].
GetTaskByID
(
static_cast
<
int
>
(
task_no
))));
res
.
InsertTask
(
std
::
move
(
category_tasks
[
i
].
GetTaskByID
(
static_cast
<
int
>
(
task_no
))));
count
++
;
}
}
}
}
}
else
{
}
else
{
...
@@ -109,9 +113,12 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
...
@@ -109,9 +113,12 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
if
(
num_elements
!=
std
::
numeric_limits
<
int64_t
>::
max
())
{
if
(
num_elements
!=
std
::
numeric_limits
<
int64_t
>::
max
())
{
maxTasks
=
static_cast
<
decltype
(
maxTasks
)
>
(
num_elements
);
maxTasks
=
static_cast
<
decltype
(
maxTasks
)
>
(
num_elements
);
}
}
int64_t
count
=
0
;
for
(
uint32_t
i
=
0
;
i
<
total_categories
;
i
++
)
{
for
(
uint32_t
i
=
0
;
i
<
total_categories
;
i
++
)
{
for
(
uint32_t
j
=
0
;
j
<
maxTasks
;
j
++
)
{
for
(
uint32_t
j
=
0
;
j
<
maxTasks
;
j
++
)
{
if
(
num_samples
!=
0
&&
count
==
num_samples
)
break
;
res
.
InsertTask
(
category_tasks
[
i
].
GetRandomTask
());
res
.
InsertTask
(
category_tasks
[
i
].
GetRandomTask
());
count
++
;
}
}
}
}
}
}
...
...
mindspore/dataset/engine/samplers.py
浏览文件 @
7341421d
...
@@ -359,7 +359,8 @@ class PKSampler(BuiltinSampler):
...
@@ -359,7 +359,8 @@ class PKSampler(BuiltinSampler):
if
not
self
.
class_column
or
not
isinstance
(
self
.
class_column
,
str
):
if
not
self
.
class_column
or
not
isinstance
(
self
.
class_column
,
str
):
raise
ValueError
(
"class_column should be a not empty string value,
\
raise
ValueError
(
"class_column should be a not empty string value,
\
but got class_column={}"
.
format
(
class_column
))
but got class_column={}"
.
format
(
class_column
))
c_sampler
=
cde
.
MindrecordPkSampler
(
self
.
num_val
,
self
.
class_column
,
self
.
shuffle
)
num_samples
=
self
.
num_samples
if
self
.
num_samples
is
not
None
else
0
c_sampler
=
cde
.
MindrecordPkSampler
(
self
.
num_val
,
self
.
class_column
,
self
.
shuffle
,
num_samples
)
c_child_sampler
=
self
.
create_child_for_minddataset
()
c_child_sampler
=
self
.
create_child_for_minddataset
()
c_sampler
.
add_child
(
c_child_sampler
)
c_sampler
.
add_child
(
c_child_sampler
)
return
c_sampler
return
c_sampler
...
...
mindspore/mindrecord/tools/tfrecord_to_mr.py
浏览文件 @
7341421d
...
@@ -104,7 +104,7 @@ class TFRecordToMR:
...
@@ -104,7 +104,7 @@ class TFRecordToMR:
Args:
Args:
source (str): the TFRecord file to be transformed.
source (str): the TFRecord file to be transformed.
destination (str): the MindRecord file path to tranform into.
destination (str): the MindRecord file path to tranform into.
feature_dict (dict): a dictionary tha
n
states the feature type, i.e.
feature_dict (dict): a dictionary tha
t
states the feature type, i.e.
feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string),
\
feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string),
\
"yyyy": tf.io.FixedLenFeature([], tf.int64)}
"yyyy": tf.io.FixedLenFeature([], tf.int64)}
...
...
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
浏览文件 @
7341421d
...
@@ -162,7 +162,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
...
@@ -162,7 +162,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"file_name"
,
"label"
};
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"file_name"
,
"label"
};
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
ops
;
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
ops
;
ops
.
push_back
(
std
::
make_shared
<
ShardPkSample
>
(
"label"
,
2
));
ops
.
push_back
(
std
::
make_shared
<
ShardPkSample
>
(
"label"
,
2
,
0
));
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
({
file_name
},
true
,
4
,
column_list
,
ops
);
dataset
.
Open
({
file_name
},
true
,
4
,
column_list
,
ops
);
...
@@ -187,7 +187,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
...
@@ -187,7 +187,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"file_name"
,
"label"
};
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"file_name"
,
"label"
};
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
ops
;
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
ops
;
ops
.
push_back
(
std
::
make_shared
<
ShardPkSample
>
(
"label"
,
2
,
3
,
0
));
ops
.
push_back
(
std
::
make_shared
<
ShardPkSample
>
(
"label"
,
2
,
3
,
0
,
0
));
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
({
file_name
},
true
,
4
,
column_list
,
ops
);
dataset
.
Open
({
file_name
},
true
,
4
,
column_list
,
ops
);
...
@@ -204,7 +204,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
...
@@ -204,7 +204,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
}
}
dataset
.
Finish
();
dataset
.
Finish
();
ASSERT_TRUE
(
i
==
6
);
ASSERT_TRUE
(
i
==
6
);
}
// namespace mindrecord
}
TEST_F
(
TestShardOperator
,
TestShardCategory
)
{
TEST_F
(
TestShardOperator
,
TestShardCategory
)
{
MS_LOG
(
INFO
)
<<
common
::
SafeCStr
(
FormatInfo
(
"Test read imageNet"
));
MS_LOG
(
INFO
)
<<
common
::
SafeCStr
(
FormatInfo
(
"Test read imageNet"
));
...
...
tests/ut/python/dataset/test_minddataset_sampler.py
浏览文件 @
7341421d
...
@@ -101,7 +101,6 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
...
@@ -101,7 +101,6 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
num_iter
+=
1
num_iter
+=
1
def
test_cv_minddataset_pk_sample_shuffle
(
add_and_remove_cv_file
):
def
test_cv_minddataset_pk_sample_shuffle
(
add_and_remove_cv_file
):
"""tutorial for cv minderdataset."""
"""tutorial for cv minderdataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
...
@@ -120,9 +119,51 @@ def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
...
@@ -120,9 +119,51 @@ def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
logger
.
info
(
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
num_iter
+=
1
num_iter
+=
1
assert
num_iter
==
9
def
test_cv_minddataset_pk_sample_shuffle_1
(
add_and_remove_cv_file
):
"""tutorial for cv minderdataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
sampler
=
ds
.
PKSampler
(
3
,
None
,
True
,
'label'
,
5
)
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
+
"0"
,
columns_list
,
num_readers
,
sampler
=
sampler
)
assert
data_set
.
get_dataset_size
()
==
5
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
"-------------- cv reader basic: {} ------------------------"
.
format
(
num_iter
))
logger
.
info
(
"-------------- item[file_name]:
\
{}------------------------"
.
format
(
to_str
(
item
[
"file_name"
])))
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
num_iter
+=
1
assert
num_iter
==
5
def
test_cv_minddataset_pk_sample_shuffle_2
(
add_and_remove_cv_file
):
"""tutorial for cv minderdataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
sampler
=
ds
.
PKSampler
(
3
,
None
,
True
,
'label'
,
10
)
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
+
"0"
,
columns_list
,
num_readers
,
sampler
=
sampler
)
assert
data_set
.
get_dataset_size
()
==
9
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
"-------------- cv reader basic: {} ------------------------"
.
format
(
num_iter
))
logger
.
info
(
"-------------- item[file_name]:
\
{}------------------------"
.
format
(
to_str
(
item
[
"file_name"
])))
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
num_iter
+=
1
assert
num_iter
==
9
def
test_cv_minddataset_pk_sample_out_of_range
(
add_and_remove_cv_file
):
def
test_cv_minddataset_pk_sample_out_of_range
_0
(
add_and_remove_cv_file
):
"""tutorial for cv minderdataset."""
"""tutorial for cv minderdataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
num_readers
=
4
...
@@ -139,6 +180,45 @@ def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file):
...
@@ -139,6 +180,45 @@ def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file):
logger
.
info
(
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
num_iter
+=
1
num_iter
+=
1
assert
num_iter
==
15
def
test_cv_minddataset_pk_sample_out_of_range_1
(
add_and_remove_cv_file
):
"""tutorial for cv minderdataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
sampler
=
ds
.
PKSampler
(
5
,
None
,
True
,
'label'
,
20
)
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
+
"0"
,
columns_list
,
num_readers
,
sampler
=
sampler
)
assert
data_set
.
get_dataset_size
()
==
15
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
"-------------- cv reader basic: {} ------------------------"
.
format
(
num_iter
))
logger
.
info
(
"-------------- item[file_name]:
\
{}------------------------"
.
format
(
to_str
(
item
[
"file_name"
])))
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
num_iter
+=
1
assert
num_iter
==
15
def
test_cv_minddataset_pk_sample_out_of_range_2
(
add_and_remove_cv_file
):
"""tutorial for cv minderdataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
sampler
=
ds
.
PKSampler
(
5
,
None
,
True
,
'label'
,
10
)
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
+
"0"
,
columns_list
,
num_readers
,
sampler
=
sampler
)
assert
data_set
.
get_dataset_size
()
==
10
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
"-------------- cv reader basic: {} ------------------------"
.
format
(
num_iter
))
logger
.
info
(
"-------------- item[file_name]:
\
{}------------------------"
.
format
(
to_str
(
item
[
"file_name"
])))
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
num_iter
+=
1
assert
num_iter
==
10
def
test_cv_minddataset_subset_random_sample_basic
(
add_and_remove_cv_file
):
def
test_cv_minddataset_subset_random_sample_basic
(
add_and_remove_cv_file
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录