Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
f1fa2a99
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f1fa2a99
编写于
4月 14, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 14, 2020
浏览文件
操作
浏览文件
下载
差异文件
!273 [MD] update subset random sampler in minddataset
Merge pull request !273 from liyong126/mindrecord_subset_sampler_python
上级
511acd29
0ce83e39
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
115 addition
and
107 deletion
+115
-107
mindspore/ccsrc/dataset/api/de_pipeline.cc
mindspore/ccsrc/dataset/api/de_pipeline.cc
+4
-30
mindspore/ccsrc/dataset/api/de_pipeline.h
mindspore/ccsrc/dataset/api/de_pipeline.h
+0
-3
mindspore/ccsrc/dataset/api/python_bindings.cc
mindspore/ccsrc/dataset/api/python_bindings.cc
+8
-0
mindspore/ccsrc/mindrecord/include/shard_category.h
mindspore/ccsrc/mindrecord/include/shard_category.h
+1
-1
mindspore/ccsrc/mindrecord/include/shard_operator.h
mindspore/ccsrc/mindrecord/include/shard_operator.h
+19
-1
mindspore/ccsrc/mindrecord/include/shard_sample.h
mindspore/ccsrc/mindrecord/include/shard_sample.h
+7
-3
mindspore/ccsrc/mindrecord/include/shard_shuffle.h
mindspore/ccsrc/mindrecord/include/shard_shuffle.h
+1
-1
mindspore/ccsrc/mindrecord/io/shard_reader.cc
mindspore/ccsrc/mindrecord/io/shard_reader.cc
+9
-2
mindspore/ccsrc/mindrecord/meta/shard_category.cc
mindspore/ccsrc/mindrecord/meta/shard_category.cc
+1
-1
mindspore/ccsrc/mindrecord/meta/shard_sample.cc
mindspore/ccsrc/mindrecord/meta/shard_sample.cc
+14
-3
mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc
mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc
+1
-1
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+1
-3
mindspore/dataset/engine/samplers.py
mindspore/dataset/engine/samplers.py
+2
-0
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
+47
-48
tests/ut/python/dataset/test_minddataset_sampler.py
tests/ut/python/dataset/test_minddataset_sampler.py
+0
-10
未找到文件。
mindspore/ccsrc/dataset/api/de_pipeline.cc
浏览文件 @
f1fa2a99
...
...
@@ -391,30 +391,6 @@ Status DEPipeline::CheckMindRecordPartitionInfo(const py::dict &args, std::vecto
return
Status
::
OK
();
}
Status
DEPipeline
::
GetMindrecordSampler
(
const
std
::
string
&
sampler_name
,
const
py
::
dict
&
args
,
std
::
shared_ptr
<
mindrecord
::
ShardOperator
>
*
ptr
)
{
std
::
vector
<
int
>
indices
;
for
(
auto
&
arg
:
args
)
{
std
::
string
key
=
py
::
str
(
arg
.
first
);
py
::
handle
value
=
arg
.
second
;
if
(
!
value
.
is_none
())
{
if
(
key
==
"indices"
)
{
indices
=
ToIntVector
(
value
);
}
else
{
std
::
string
err_msg
=
"ERROR: parameter "
+
key
+
" is invalid."
;
RETURN_STATUS_UNEXPECTED
(
err_msg
);
}
}
}
if
(
sampler_name
==
"SubsetRandomSampler"
)
{
*
ptr
=
std
::
make_shared
<
mindrecord
::
ShardSample
>
(
indices
);
}
else
{
std
::
string
err_msg
=
"ERROR: parameter sampler_name is invalid."
;
RETURN_STATUS_UNEXPECTED
(
err_msg
);
}
return
Status
::
OK
();
}
Status
DEPipeline
::
ParseMindRecordOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
)
{
if
(
args
[
"dataset_file"
].
is_none
())
{
std
::
string
err_msg
=
"Error: at least one of dataset_files is missing"
;
...
...
@@ -446,12 +422,10 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
}
else
if
(
key
==
"global_shuffle"
&&
ToBool
(
value
)
==
true
)
{
uint32_t
seed
=
args
[
"partitions"
].
is_none
()
?
GetSeed
()
:
0
;
operators
.
push_back
(
std
::
make_shared
<
mindrecord
::
ShardShuffle
>
(
seed
));
}
else
if
(
key
==
"sampler_name"
)
{
std
::
shared_ptr
<
mindrecord
::
ShardOperator
>
sample_op
;
auto
ret
=
GetMindrecordSampler
(
ToString
(
value
),
args
[
"sampler_params"
],
&
sample_op
);
if
(
Status
::
OK
()
!=
ret
)
{
return
ret
;
}
}
else
if
(
key
==
"sampler"
)
{
auto
create
=
py
::
reinterpret_borrow
<
py
::
object
>
(
value
).
attr
(
"_create_for_minddataset"
);
std
::
shared_ptr
<
mindrecord
::
ShardOperator
>
sample_op
=
create
().
cast
<
std
::
shared_ptr
<
mindrecord
::
ShardOperator
>>
();
operators
.
push_back
(
sample_op
);
}
}
...
...
mindspore/ccsrc/dataset/api/de_pipeline.h
浏览文件 @
f1fa2a99
...
...
@@ -145,9 +145,6 @@ class DEPipeline {
Status
ParseCelebAOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
);
Status
GetMindrecordSampler
(
const
std
::
string
&
sampler_name
,
const
py
::
dict
&
args
,
std
::
shared_ptr
<
mindrecord
::
ShardOperator
>
*
ptr
);
private:
// Execution tree that links the dataset operators.
std
::
shared_ptr
<
ExecutionTree
>
tree_
;
...
...
mindspore/ccsrc/dataset/api/python_bindings.cc
浏览文件 @
f1fa2a99
...
...
@@ -54,6 +54,9 @@
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/kernels/data/to_float16_op.h"
#include "dataset/util/random.h"
#include "mindrecord/include/shard_operator.h"
#include "mindrecord/include/shard_sample.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
...
...
@@ -382,6 +385,7 @@ void bindTensorOps4(py::module *m) {
void
bindSamplerOps
(
py
::
module
*
m
)
{
(
void
)
py
::
class_
<
Sampler
,
std
::
shared_ptr
<
Sampler
>>
(
*
m
,
"Sampler"
);
(
void
)
py
::
class_
<
mindrecord
::
ShardOperator
,
std
::
shared_ptr
<
mindrecord
::
ShardOperator
>>
(
*
m
,
"ShardOperator"
);
(
void
)
py
::
class_
<
DistributedSampler
,
Sampler
,
std
::
shared_ptr
<
DistributedSampler
>>
(
*
m
,
"DistributedSampler"
)
.
def
(
py
::
init
<
int64_t
,
int64_t
,
bool
,
uint32_t
>
(),
py
::
arg
(
"numDev"
),
py
::
arg
(
"devId"
),
py
::
arg
(
"shuffle"
),
...
...
@@ -399,6 +403,10 @@ void bindSamplerOps(py::module *m) {
(
void
)
py
::
class_
<
SubsetRandomSampler
,
Sampler
,
std
::
shared_ptr
<
SubsetRandomSampler
>>
(
*
m
,
"SubsetRandomSampler"
)
.
def
(
py
::
init
<
std
::
vector
<
int64_t
>>
(),
py
::
arg
(
"indices"
));
(
void
)
py
::
class_
<
mindrecord
::
ShardSample
,
mindrecord
::
ShardOperator
,
std
::
shared_ptr
<
mindrecord
::
ShardSample
>>
(
*
m
,
"MindrecordSubsetRandomSampler"
)
.
def
(
py
::
init
<
std
::
vector
<
int64_t
>
,
uint32_t
>
(),
py
::
arg
(
"indices"
),
py
::
arg
(
"seed"
)
=
GetSeed
());
(
void
)
py
::
class_
<
WeightedRandomSampler
,
Sampler
,
std
::
shared_ptr
<
WeightedRandomSampler
>>
(
*
m
,
"WeightedRandomSampler"
)
.
def
(
py
::
init
<
std
::
vector
<
double
>
,
int64_t
,
bool
>
(),
py
::
arg
(
"weights"
),
py
::
arg
(
"numSamples"
),
py
::
arg
(
"replacement"
));
...
...
mindspore/ccsrc/mindrecord/include/shard_category.h
浏览文件 @
f1fa2a99
...
...
@@ -32,7 +32,7 @@ class ShardCategory : public ShardOperator {
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>
&
get_categories
()
const
;
MSRStatus
operator
()
(
ShardTask
&
tasks
)
override
;
MSRStatus
execute
(
ShardTask
&
tasks
)
override
;
private:
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>
categories_
;
...
...
mindspore/ccsrc/mindrecord/include/shard_operator.h
浏览文件 @
f1fa2a99
...
...
@@ -24,7 +24,25 @@ namespace mindrecord {
class
ShardOperator
{
public:
virtual
~
ShardOperator
()
=
default
;
virtual
MSRStatus
operator
()(
ShardTask
&
tasks
)
=
0
;
MSRStatus
operator
()(
ShardTask
&
tasks
)
{
if
(
SUCCESS
!=
this
->
pre_execute
(
tasks
))
{
return
FAILED
;
}
if
(
SUCCESS
!=
this
->
execute
(
tasks
))
{
return
FAILED
;
}
if
(
SUCCESS
!=
this
->
suf_execute
(
tasks
))
{
return
FAILED
;
}
return
SUCCESS
;
}
virtual
MSRStatus
pre_execute
(
ShardTask
&
tasks
)
{
return
SUCCESS
;
}
virtual
MSRStatus
execute
(
ShardTask
&
tasks
)
=
0
;
virtual
MSRStatus
suf_execute
(
ShardTask
&
tasks
)
{
return
SUCCESS
;
}
};
}
// namespace mindrecord
}
// namespace mindspore
...
...
mindspore/ccsrc/mindrecord/include/shard_sample.h
浏览文件 @
f1fa2a99
...
...
@@ -17,10 +17,12 @@
#ifndef MINDRECORD_INCLUDE_SHARD_SAMPLE_H_
#define MINDRECORD_INCLUDE_SHARD_SAMPLE_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "mindrecord/include/shard_operator.h"
#include "mindrecord/include/shard_shuffle.h"
namespace
mindspore
{
namespace
mindrecord
{
...
...
@@ -32,21 +34,23 @@ class ShardSample : public ShardOperator {
ShardSample
(
int
num
,
int
den
,
int
par
);
explicit
ShardSample
(
const
std
::
vector
<
int
>
&
indices
);
ShardSample
(
const
std
::
vector
<
int64_t
>
&
indices
,
uint32_t
seed
);
~
ShardSample
()
override
{};
const
std
::
pair
<
int
,
int
>
get_partitions
()
const
;
MSRStatus
operator
()(
ShardTask
&
tasks
)
override
;
MSRStatus
execute
(
ShardTask
&
tasks
)
override
;
MSRStatus
suf_execute
(
ShardTask
&
tasks
)
override
;
private:
int
numerator_
;
int
denominator_
;
int
no_of_samples_
;
int
partition_id_
;
std
::
vector
<
int
>
indices_
;
std
::
vector
<
int
64_t
>
indices_
;
SamplerType
sampler_type_
;
std
::
shared_ptr
<
ShardShuffle
>
shuffle_op_
;
};
}
// namespace mindrecord
}
// namespace mindspore
...
...
mindspore/ccsrc/mindrecord/include/shard_shuffle.h
浏览文件 @
f1fa2a99
...
...
@@ -28,7 +28,7 @@ class ShardShuffle : public ShardOperator {
~
ShardShuffle
()
override
{};
MSRStatus
operator
()
(
ShardTask
&
tasks
)
override
;
MSRStatus
execute
(
ShardTask
&
tasks
)
override
;
private:
uint32_t
shuffle_seed_
;
...
...
mindspore/ccsrc/mindrecord/io/shard_reader.cc
浏览文件 @
f1fa2a99
...
...
@@ -779,8 +779,12 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) {
// Sort row group by (group_id, shard_id), prepare for parallel reading
std
::
sort
(
row_group_summary
.
begin
(),
row_group_summary
.
end
(),
ResortRowGroups
);
CreateTasks
(
row_group_summary
,
operators_
);
MS_LOG
(
INFO
)
<<
"Launching read threads"
;
if
(
CreateTasks
(
row_group_summary
,
operators_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
"Failed to launch read threads."
;
interrupt_
=
true
;
return
FAILED
;
}
MS_LOG
(
INFO
)
<<
"Launching read threads."
;
if
(
isSimpleReader
)
return
SUCCESS
;
...
...
@@ -1152,6 +1156,9 @@ std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetBlockNext()
}
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
ShardReader
::
GetNext
()
{
if
(
interrupt_
)
{
return
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
();
}
if
(
block_reader_
)
return
GetBlockNext
();
if
(
deliver_id_
>=
static_cast
<
int
>
(
tasks_
.
Size
()))
{
return
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
();
...
...
mindspore/ccsrc/mindrecord/meta/shard_category.cc
浏览文件 @
f1fa2a99
...
...
@@ -23,6 +23,6 @@ ShardCategory::ShardCategory(const std::vector<std::pair<std::string, std::strin
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>
&
ShardCategory
::
get_categories
()
const
{
return
categories_
;
}
MSRStatus
ShardCategory
::
operator
()
(
ShardTask
&
tasks
)
{
return
SUCCESS
;
}
MSRStatus
ShardCategory
::
execute
(
ShardTask
&
tasks
)
{
return
SUCCESS
;
}
}
// namespace mindrecord
}
// namespace mindspore
mindspore/ccsrc/mindrecord/meta/shard_sample.cc
浏览文件 @
f1fa2a99
...
...
@@ -46,13 +46,15 @@ ShardSample::ShardSample(int num, int den, int par)
indices_
({}),
sampler_type_
(
kCustomTopPercentSampler
)
{}
ShardSample
::
ShardSample
(
const
std
::
vector
<
int
>
&
indices
)
ShardSample
::
ShardSample
(
const
std
::
vector
<
int
64_t
>
&
indices
,
uint32_t
seed
)
:
numerator_
(
0
),
denominator_
(
0
),
no_of_samples_
(
0
),
partition_id_
(
0
),
indices_
(
indices
),
sampler_type_
(
kSubsetRandomSampler
)
{}
sampler_type_
(
kSubsetRandomSampler
)
{
shuffle_op_
=
std
::
make_shared
<
ShardShuffle
>
(
seed
);
}
const
std
::
pair
<
int
,
int
>
ShardSample
::
get_partitions
()
const
{
if
(
numerator_
==
1
&&
denominator_
>
1
)
{
...
...
@@ -61,7 +63,7 @@ const std::pair<int, int> ShardSample::get_partitions() const {
return
std
::
pair
<
int
,
int
>
(
-
1
,
-
1
);
}
MSRStatus
ShardSample
::
operator
()
(
ShardTask
&
tasks
)
{
MSRStatus
ShardSample
::
execute
(
ShardTask
&
tasks
)
{
int
no_of_categories
=
static_cast
<
int
>
(
tasks
.
categories
);
int
total_no
=
static_cast
<
int
>
(
tasks
.
Size
());
...
...
@@ -115,5 +117,14 @@ MSRStatus ShardSample::operator()(ShardTask &tasks) {
}
return
SUCCESS
;
}
MSRStatus
ShardSample
::
suf_execute
(
ShardTask
&
tasks
)
{
if
(
sampler_type_
==
kSubsetRandomSampler
)
{
if
(
SUCCESS
!=
(
*
shuffle_op_
)(
tasks
))
{
return
FAILED
;
}
}
return
SUCCESS
;
}
}
// namespace mindrecord
}
// namespace mindspore
mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc
浏览文件 @
f1fa2a99
...
...
@@ -22,7 +22,7 @@ namespace mindspore {
namespace
mindrecord
{
ShardShuffle
::
ShardShuffle
(
uint32_t
seed
)
:
shuffle_seed_
(
seed
)
{}
MSRStatus
ShardShuffle
::
operator
()
(
ShardTask
&
tasks
)
{
MSRStatus
ShardShuffle
::
execute
(
ShardTask
&
tasks
)
{
if
(
tasks
.
categories
<
1
)
{
return
FAILED
;
}
...
...
mindspore/dataset/engine/datasets.py
浏览文件 @
f1fa2a99
...
...
@@ -1683,9 +1683,7 @@ class MindDataset(SourceDataset):
args
[
"block_reader"
]
=
self
.
block_reader
args
[
"num_shards"
]
=
self
.
num_shards
args
[
"shard_id"
]
=
self
.
shard_id
if
self
.
sampler
:
args
[
"sampler_name"
]
=
self
.
sampler
.
__class__
.
__name__
args
[
"sampler_params"
]
=
self
.
sampler
.
__dict__
args
[
"sampler"
]
=
self
.
sampler
return
args
def
get_dataset_size
(
self
):
...
...
mindspore/dataset/engine/samplers.py
浏览文件 @
f1fa2a99
...
...
@@ -195,6 +195,8 @@ class SubsetRandomSampler():
def
create
(
self
):
return
cde
.
SubsetRandomSampler
(
self
.
indices
)
def
_create_for_minddataset
(
self
):
return
cde
.
MindrecordSubsetRandomSampler
(
self
.
indices
)
class
WeightedRandomSampler
():
"""
...
...
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
浏览文件 @
f1fa2a99
...
...
@@ -30,9 +30,9 @@
#include "mindrecord/include/shard_shuffle.h"
#include "ut_common.h"
using
mindspore
::
MsLogLevel
::
INFO
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
LogStream
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
MsLogLevel
::
INFO
;
namespace
mindspore
{
namespace
mindrecord
{
...
...
@@ -65,31 +65,31 @@ TEST_F(TestShardOperator, TestShardSampleBasic) {
ASSERT_TRUE
(
i
<=
kSampleCount
);
}
//
TEST_F(TestShardOperator, TestShardSampleWrongNumber) {
//
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));
//
//
std::string file_name = "./imagenet.shard01";
//
auto column_list = std::vector<std::string>{"file_name"};
//
//
const int kNum = 5;
//
const int kDen = 0;
//
std::vector<std::shared_ptr<ShardOperator>> ops;
//
ops.push_back(std::make_shared<ShardSample>(kNum, kDen));
//
//
ShardReader dataset;
//
dataset.Open(file_name, 4, column_list, ops);
//
dataset.Launch();
//
//
int i = 0;
//
while (true) {
//
auto x = dataset.GetNext();
//
if (x.empty()) break;
//
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
//
i++;
//
}
//
dataset.Finish();
//
ASSERT_TRUE(i <= 5);
//
}
TEST_F
(
TestShardOperator
,
TestShardSampleWrongNumber
)
{
MS_LOG
(
INFO
)
<<
common
::
SafeCStr
(
FormatInfo
(
"Test read imageNet"
));
std
::
string
file_name
=
"./imagenet.shard01"
;
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"file_name"
};
const
int
kNum
=
5
;
const
int
kDen
=
0
;
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
ops
;
ops
.
push_back
(
std
::
make_shared
<
ShardSample
>
(
kNum
,
kDen
));
ShardReader
dataset
;
dataset
.
Open
(
file_name
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
int
i
=
0
;
while
(
true
)
{
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
]);
i
++
;
}
dataset
.
Finish
();
ASSERT_TRUE
(
i
<=
5
);
}
TEST_F
(
TestShardOperator
,
TestShardSampleRatio
)
{
MS_LOG
(
INFO
)
<<
common
::
SafeCStr
(
FormatInfo
(
"Test read imageNet"
));
...
...
@@ -117,7 +117,6 @@ TEST_F(TestShardOperator, TestShardSampleRatio) {
ASSERT_TRUE
(
i
<=
10
);
}
TEST_F
(
TestShardOperator
,
TestShardSamplePartition
)
{
MS_LOG
(
INFO
)
<<
common
::
SafeCStr
(
FormatInfo
(
"Test read imageNet"
));
std
::
string
file_name
=
"./imagenet.shard01"
;
...
...
@@ -170,8 +169,8 @@ TEST_F(TestShardOperator, TestShardCategory) {
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
i
++
;
ASSERT_TRUE
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
]
==
categories
[
category_no
].
second
);
...
...
@@ -199,8 +198,8 @@ TEST_F(TestShardOperator, TestShardShuffle) {
while
(
true
)
{
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
i
++
;
}
dataset
.
Finish
();
...
...
@@ -224,8 +223,8 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) {
while
(
true
)
{
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
i
++
;
}
dataset
.
Finish
();
...
...
@@ -251,8 +250,8 @@ TEST_F(TestShardOperator, TestShardShuffleSample) {
while
(
true
)
{
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
i
++
;
}
dataset
.
Finish
();
...
...
@@ -278,8 +277,8 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) {
while
(
true
)
{
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
i
++
;
}
dataset
.
Finish
();
...
...
@@ -307,8 +306,8 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) {
while
(
true
)
{
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
i
++
;
auto
y
=
compare_dataset
.
GetNext
();
...
...
@@ -342,8 +341,8 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) {
while
(
true
)
{
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
i
++
;
ASSERT_TRUE
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
]
==
categories
[
category_no
].
second
);
...
...
@@ -376,8 +375,8 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
while
(
true
)
{
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
i
++
;
ASSERT_TRUE
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
]
==
categories
[
category_no
].
second
);
category_no
++
;
...
...
@@ -410,8 +409,8 @@ TEST_F(TestShardOperator, TestShardCategorySample) {
while
(
true
)
{
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
i
++
;
ASSERT_TRUE
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
]
==
categories
[
category_no
].
second
);
...
...
@@ -448,8 +447,8 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) {
while
(
true
)
{
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
i
++
;
ASSERT_TRUE
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
]
==
categories
[
category_no
].
second
);
...
...
tests/ut/python/dataset/test_minddataset_sampler.py
浏览文件 @
f1fa2a99
...
...
@@ -81,8 +81,6 @@ def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file):
"-------------- item[file_name]: {} ------------------------"
.
format
(
item
[
"file_name"
]))
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
assert
data
[
indices
[
num_iter
]][
'file_name'
]
==
""
.
join
(
[
chr
(
x
)
for
x
in
item
[
'file_name'
]])
num_iter
+=
1
assert
num_iter
==
5
...
...
@@ -107,8 +105,6 @@ def test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file):
"-------------- item[file_name]: {} ------------------------"
.
format
(
item
[
"file_name"
]))
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
assert
data
[
indices
[
num_iter
]][
'file_name'
]
==
""
.
join
(
[
chr
(
x
)
for
x
in
item
[
'file_name'
]])
num_iter
+=
1
assert
num_iter
==
6
...
...
@@ -133,8 +129,6 @@ def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file):
"-------------- item[file_name]: {} ------------------------"
.
format
(
item
[
"file_name"
]))
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
assert
data
[
indices
[
num_iter
]][
'file_name'
]
==
""
.
join
(
[
chr
(
x
)
for
x
in
item
[
'file_name'
]])
num_iter
+=
1
assert
num_iter
==
0
...
...
@@ -159,8 +153,6 @@ def test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file):
"-------------- item[file_name]: {} ------------------------"
.
format
(
item
[
"file_name"
]))
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
assert
data
[
indices
[
num_iter
]
%
len
(
data
)][
'file_name'
]
==
""
.
join
([
chr
(
x
)
for
x
in
item
[
'file_name'
]])
num_iter
+=
1
assert
num_iter
==
5
...
...
@@ -185,8 +177,6 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
"-------------- item[file_name]: {} ------------------------"
.
format
(
item
[
"file_name"
]))
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
assert
data
[
indices
[
num_iter
]
%
len
(
data
)][
'file_name'
]
==
""
.
join
([
chr
(
x
)
for
x
in
item
[
'file_name'
]])
num_iter
+=
1
assert
num_iter
==
5
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录