Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
beefb20c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
beefb20c
编写于
6月 11, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 11, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1983 Remove inheritance of Sampler from DatasetOp
Merge pull request !1983 from JesseKLee/sampler
上级
5ffb0040
255adf7c
变更
25
隐藏空白更改
内联
并排
Showing
25 changed file
with
77 addition
and
115 deletion
+77
-115
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc
...spore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc
+3
-3
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
+3
-3
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
...ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
+3
-3
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
...ore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
+3
-3
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc
+3
-3
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
...t/engine/datasetops/source/sampler/distributed_sampler.cc
+2
-2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h
...et/engine/datasetops/source/sampler/distributed_sampler.h
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc
...rc/dataset/engine/datasetops/source/sampler/pk_sampler.cc
+2
-2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h
...src/dataset/engine/datasetops/source/sampler/pk_sampler.h
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc
...ataset/engine/datasetops/source/sampler/python_sampler.cc
+2
-2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h
...dataset/engine/datasetops/source/sampler/python_sampler.h
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc
...ataset/engine/datasetops/source/sampler/random_sampler.cc
+2
-2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h
...dataset/engine/datasetops/source/sampler/random_sampler.h
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
...ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
+4
-27
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h
.../ccsrc/dataset/engine/datasetops/source/sampler/sampler.h
+6
-21
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc
...et/engine/datasetops/source/sampler/sequential_sampler.cc
+2
-2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h
...set/engine/datasetops/source/sampler/sequential_sampler.h
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc
...engine/datasetops/source/sampler/subset_random_sampler.cc
+2
-2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h
.../engine/datasetops/source/sampler/subset_random_sampler.h
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc
...gine/datasetops/source/sampler/weighted_random_sampler.cc
+2
-2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h
...ngine/datasetops/source/sampler/weighted_random_sampler.h
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
+3
-3
tests/ut/cpp/dataset/stand_alone_samplers_test.cc
tests/ut/cpp/dataset/stand_alone_samplers_test.cc
+5
-5
tests/ut/cpp/dataset/subset_random_sampler_test.cc
tests/ut/cpp/dataset/subset_random_sampler_test.cc
+7
-7
tests/ut/cpp/dataset/weighted_random_sampler_test.cc
tests/ut/cpp/dataset/weighted_random_sampler_test.cc
+16
-16
未找到文件。
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc
浏览文件 @
beefb20c
...
@@ -263,7 +263,7 @@ std::vector<std::string> CelebAOp::Split(const std::string &line) {
...
@@ -263,7 +263,7 @@ std::vector<std::string> CelebAOp::Split(const std::string &line) {
Status
CelebAOp
::
operator
()()
{
Status
CelebAOp
::
operator
()()
{
RETURN_IF_NOT_OK
(
LaunchThreadsAndInitOp
());
RETURN_IF_NOT_OK
(
LaunchThreadsAndInitOp
());
std
::
unique_ptr
<
DataBuffer
>
data_buffer
;
std
::
unique_ptr
<
DataBuffer
>
data_buffer
;
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Buffer
(
&
data_buffer
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Sample
(
&
data_buffer
));
RETURN_IF_NOT_OK
(
AddIOBlock
(
&
data_buffer
));
RETURN_IF_NOT_OK
(
AddIOBlock
(
&
data_buffer
));
return
Status
::
OK
();
return
Status
::
OK
();
}
}
...
@@ -291,7 +291,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
...
@@ -291,7 +291,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
keys
.
clear
();
keys
.
clear
();
}
}
}
}
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Buffer
(
data_buffer
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Sample
(
data_buffer
));
}
}
if
(
!
keys
.
empty
())
{
if
(
!
keys
.
empty
())
{
...
@@ -313,7 +313,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
...
@@ -313,7 +313,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
io_block_queues_
[(
buff_count
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
)));
io_block_queues_
[(
buff_count
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
)));
RETURN_IF_NOT_OK
(
wp_
.
Wait
());
// Master thread goes to sleep after it has made all the IOBlocks
RETURN_IF_NOT_OK
(
wp_
.
Wait
());
// Master thread goes to sleep after it has made all the IOBlocks
wp_
.
Clear
();
wp_
.
Clear
();
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Buffer
(
data_buffer
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Sample
(
data_buffer
));
}
}
}
}
}
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
浏览文件 @
beefb20c
...
@@ -100,7 +100,7 @@ CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const
...
@@ -100,7 +100,7 @@ CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const
Status
CifarOp
::
operator
()()
{
Status
CifarOp
::
operator
()()
{
RETURN_IF_NOT_OK
(
LaunchThreadsAndInitOp
());
RETURN_IF_NOT_OK
(
LaunchThreadsAndInitOp
());
std
::
unique_ptr
<
DataBuffer
>
sampler_buffer
;
std
::
unique_ptr
<
DataBuffer
>
sampler_buffer
;
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Buffer
(
&
sampler_buffer
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Sample
(
&
sampler_buffer
));
while
(
true
)
{
// each iterator is 1 epoch
while
(
true
)
{
// each iterator is 1 epoch
std
::
vector
<
int64_t
>
keys
;
std
::
vector
<
int64_t
>
keys
;
keys
.
reserve
(
rows_per_buffer_
);
keys
.
reserve
(
rows_per_buffer_
);
...
@@ -118,7 +118,7 @@ Status CifarOp::operator()() {
...
@@ -118,7 +118,7 @@ Status CifarOp::operator()() {
keys
.
clear
();
keys
.
clear
();
}
}
}
}
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Buffer
(
&
sampler_buffer
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Sample
(
&
sampler_buffer
));
}
}
if
(
keys
.
empty
()
==
false
)
{
if
(
keys
.
empty
()
==
false
)
{
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
...
@@ -139,7 +139,7 @@ Status CifarOp::operator()() {
...
@@ -139,7 +139,7 @@ Status CifarOp::operator()() {
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
)));
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
)));
RETURN_IF_NOT_OK
(
wp_
.
Wait
());
// Master thread goes to sleep after it has made all the IOBlocks
RETURN_IF_NOT_OK
(
wp_
.
Wait
());
// Master thread goes to sleep after it has made all the IOBlocks
wp_
.
Clear
();
wp_
.
Clear
();
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Buffer
(
&
sampler_buffer
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Sample
(
&
sampler_buffer
));
}
}
}
}
}
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
浏览文件 @
beefb20c
...
@@ -126,7 +126,7 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) {
...
@@ -126,7 +126,7 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) {
Status
ImageFolderOp
::
operator
()()
{
Status
ImageFolderOp
::
operator
()()
{
RETURN_IF_NOT_OK
(
LaunchThreadsAndInitOp
());
RETURN_IF_NOT_OK
(
LaunchThreadsAndInitOp
());
std
::
unique_ptr
<
DataBuffer
>
sampler_buffer
;
std
::
unique_ptr
<
DataBuffer
>
sampler_buffer
;
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Buffer
(
&
sampler_buffer
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Sample
(
&
sampler_buffer
));
while
(
true
)
{
// each iterator is 1 epoch
while
(
true
)
{
// each iterator is 1 epoch
std
::
vector
<
int64_t
>
keys
;
std
::
vector
<
int64_t
>
keys
;
keys
.
reserve
(
rows_per_buffer_
);
keys
.
reserve
(
rows_per_buffer_
);
...
@@ -145,7 +145,7 @@ Status ImageFolderOp::operator()() {
...
@@ -145,7 +145,7 @@ Status ImageFolderOp::operator()() {
keys
.
clear
();
keys
.
clear
();
}
}
}
}
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Buffer
(
&
sampler_buffer
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Sample
(
&
sampler_buffer
));
}
}
if
(
keys
.
empty
()
==
false
)
{
if
(
keys
.
empty
()
==
false
)
{
RETURN_IF_NOT_OK
(
RETURN_IF_NOT_OK
(
...
@@ -166,7 +166,7 @@ Status ImageFolderOp::operator()() {
...
@@ -166,7 +166,7 @@ Status ImageFolderOp::operator()() {
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
)));
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
)));
RETURN_IF_NOT_OK
(
wp_
.
Wait
());
// Master thread goes to sleep after it has made all the IOBlocks
RETURN_IF_NOT_OK
(
wp_
.
Wait
());
// Master thread goes to sleep after it has made all the IOBlocks
wp_
.
Clear
();
wp_
.
Clear
();
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Buffer
(
&
sampler_buffer
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Sample
(
&
sampler_buffer
));
}
}
}
}
}
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
浏览文件 @
beefb20c
...
@@ -88,7 +88,7 @@ ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string f
...
@@ -88,7 +88,7 @@ ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string f
Status
ManifestOp
::
operator
()()
{
Status
ManifestOp
::
operator
()()
{
RETURN_IF_NOT_OK
(
LaunchThreadsAndInitOp
());
RETURN_IF_NOT_OK
(
LaunchThreadsAndInitOp
());
std
::
unique_ptr
<
DataBuffer
>
sampler_buffer
;
std
::
unique_ptr
<
DataBuffer
>
sampler_buffer
;
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Buffer
(
&
sampler_buffer
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Sample
(
&
sampler_buffer
));
return
AddIoBlock
(
&
sampler_buffer
);
return
AddIoBlock
(
&
sampler_buffer
);
}
}
...
@@ -110,7 +110,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
...
@@ -110,7 +110,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
keys
.
clear
();
keys
.
clear
();
}
}
}
}
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Buffer
(
sampler_buffer
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Sample
(
sampler_buffer
));
}
}
if
(
keys
.
empty
()
==
false
)
{
if
(
keys
.
empty
()
==
false
)
{
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
...
@@ -131,7 +131,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
...
@@ -131,7 +131,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
)));
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
)));
RETURN_IF_NOT_OK
(
wp_
.
Wait
());
// Master thread goes to sleep after it has made all the IOBlocks
RETURN_IF_NOT_OK
(
wp_
.
Wait
());
// Master thread goes to sleep after it has made all the IOBlocks
wp_
.
Clear
();
wp_
.
Clear
();
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Buffer
(
sampler_buffer
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Sample
(
sampler_buffer
));
}
}
}
}
}
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc
浏览文件 @
beefb20c
...
@@ -98,7 +98,7 @@ Status MnistOp::TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, st
...
@@ -98,7 +98,7 @@ Status MnistOp::TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, st
Status
MnistOp
::
operator
()()
{
Status
MnistOp
::
operator
()()
{
RETURN_IF_NOT_OK
(
LaunchThreadsAndInitOp
());
RETURN_IF_NOT_OK
(
LaunchThreadsAndInitOp
());
std
::
unique_ptr
<
DataBuffer
>
sampler_buffer
;
std
::
unique_ptr
<
DataBuffer
>
sampler_buffer
;
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Buffer
(
&
sampler_buffer
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Sample
(
&
sampler_buffer
));
while
(
true
)
{
// each iterator is 1 epoch
while
(
true
)
{
// each iterator is 1 epoch
std
::
vector
<
int64_t
>
keys
;
std
::
vector
<
int64_t
>
keys
;
keys
.
reserve
(
rows_per_buffer_
);
keys
.
reserve
(
rows_per_buffer_
);
...
@@ -109,7 +109,7 @@ Status MnistOp::operator()() {
...
@@ -109,7 +109,7 @@ Status MnistOp::operator()() {
RETURN_STATUS_UNEXPECTED
(
"Sampler Tensor isn't UINT64"
);
RETURN_STATUS_UNEXPECTED
(
"Sampler Tensor isn't UINT64"
);
}
}
RETURN_IF_NOT_OK
(
TraversalSampleIds
(
sample_ids
,
&
keys
));
RETURN_IF_NOT_OK
(
TraversalSampleIds
(
sample_ids
,
&
keys
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Buffer
(
&
sampler_buffer
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Sample
(
&
sampler_buffer
));
}
}
if
(
keys
.
empty
()
==
false
)
{
if
(
keys
.
empty
()
==
false
)
{
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
...
@@ -130,7 +130,7 @@ Status MnistOp::operator()() {
...
@@ -130,7 +130,7 @@ Status MnistOp::operator()() {
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
)));
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
)));
RETURN_IF_NOT_OK
(
wp_
.
Wait
());
// Master thread goes to sleep after it has made all the IOBlocks
RETURN_IF_NOT_OK
(
wp_
.
Wait
());
// Master thread goes to sleep after it has made all the IOBlocks
wp_
.
Clear
();
wp_
.
Clear
();
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Buffer
(
&
sampler_buffer
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Sample
(
&
sampler_buffer
));
}
}
}
}
}
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
浏览文件 @
beefb20c
...
@@ -55,14 +55,14 @@ Status DistributedSampler::InitSampler() {
...
@@ -55,14 +55,14 @@ Status DistributedSampler::InitSampler() {
return
Status
::
OK
();
return
Status
::
OK
();
}
}
Status
DistributedSampler
::
GetNext
Buffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
{
Status
DistributedSampler
::
GetNext
Sample
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
{
if
(
cnt_
>
samples_per_buffer_
)
{
if
(
cnt_
>
samples_per_buffer_
)
{
RETURN_STATUS_UNEXPECTED
(
"Distributed Sampler Error"
);
RETURN_STATUS_UNEXPECTED
(
"Distributed Sampler Error"
);
}
else
if
(
cnt_
==
samples_per_buffer_
)
{
}
else
if
(
cnt_
==
samples_per_buffer_
)
{
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
}
else
{
}
else
{
if
(
HasChildSampler
())
{
if
(
HasChildSampler
())
{
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNext
Buffer
(
&
child_ids_
));
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNext
Sample
(
&
child_ids_
));
}
}
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
cnt_
,
DataBuffer
::
kDeBFlagNone
);
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
cnt_
,
DataBuffer
::
kDeBFlagNone
);
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h
浏览文件 @
beefb20c
...
@@ -40,7 +40,7 @@ class DistributedSampler : public Sampler {
...
@@ -40,7 +40,7 @@ class DistributedSampler : public Sampler {
// @param std::unique_ptr<DataBuffer> * pBuffer
// @param std::unique_ptr<DataBuffer> * pBuffer
// @param int32_t workerId
// @param int32_t workerId
// @return - The error code return
// @return - The error code return
Status
GetNext
Buffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
Status
GetNext
Sample
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
// Init sampler, called by base class or python
// Init sampler, called by base class or python
Status
InitSampler
()
override
;
Status
InitSampler
()
override
;
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc
浏览文件 @
beefb20c
...
@@ -59,14 +59,14 @@ Status PKSampler::InitSampler() {
...
@@ -59,14 +59,14 @@ Status PKSampler::InitSampler() {
return
Status
::
OK
();
return
Status
::
OK
();
}
}
Status
PKSampler
::
GetNext
Buffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
{
Status
PKSampler
::
GetNext
Sample
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
{
if
(
next_id_
>
num_samples_
||
num_samples_
==
0
)
{
if
(
next_id_
>
num_samples_
||
num_samples_
==
0
)
{
RETURN_STATUS_UNEXPECTED
(
"Index out of bound in PKSampler"
);
RETURN_STATUS_UNEXPECTED
(
"Index out of bound in PKSampler"
);
}
else
if
(
next_id_
==
num_samples_
)
{
}
else
if
(
next_id_
==
num_samples_
)
{
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
}
else
{
}
else
{
if
(
HasChildSampler
())
{
if
(
HasChildSampler
())
{
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNext
Buffer
(
&
child_ids_
));
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNext
Sample
(
&
child_ids_
));
}
}
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
next_id_
,
DataBuffer
::
kDeBFlagNone
);
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
next_id_
,
DataBuffer
::
kDeBFlagNone
);
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h
浏览文件 @
beefb20c
...
@@ -41,7 +41,7 @@ class PKSampler : public Sampler { // NOT YET FINISHED
...
@@ -41,7 +41,7 @@ class PKSampler : public Sampler { // NOT YET FINISHED
// @param std::unique_ptr<DataBuffer pBuffer
// @param std::unique_ptr<DataBuffer pBuffer
// @param int32_t workerId
// @param int32_t workerId
// @return - The error code return
// @return - The error code return
Status
GetNext
Buffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
Status
GetNext
Sample
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
// first handshake between leaf source op and Sampler. This func will determine the amount of data
// first handshake between leaf source op and Sampler. This func will determine the amount of data
// in the dataset that we can sample from.
// in the dataset that we can sample from.
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc
浏览文件 @
beefb20c
...
@@ -23,12 +23,12 @@ namespace dataset {
...
@@ -23,12 +23,12 @@ namespace dataset {
PythonSampler
::
PythonSampler
(
int64_t
num_samples
,
py
::
object
py_sampler_instance
,
int64_t
samples_per_buffer
)
PythonSampler
::
PythonSampler
(
int64_t
num_samples
,
py
::
object
py_sampler_instance
,
int64_t
samples_per_buffer
)
:
Sampler
(
num_samples
,
samples_per_buffer
),
py_sampler_instance
(
py_sampler_instance
),
need_to_reset_
(
false
)
{}
:
Sampler
(
num_samples
,
samples_per_buffer
),
py_sampler_instance
(
py_sampler_instance
),
need_to_reset_
(
false
)
{}
Status
PythonSampler
::
GetNext
Buffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
{
Status
PythonSampler
::
GetNext
Sample
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
{
if
(
need_to_reset_
)
{
if
(
need_to_reset_
)
{
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
}
else
{
}
else
{
if
(
HasChildSampler
())
{
if
(
HasChildSampler
())
{
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNext
Buffer
(
&
child_ids_
));
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNext
Sample
(
&
child_ids_
));
}
}
std
::
shared_ptr
<
Tensor
>
sample_ids
;
std
::
shared_ptr
<
Tensor
>
sample_ids
;
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h
浏览文件 @
beefb20c
...
@@ -48,7 +48,7 @@ class PythonSampler : public Sampler {
...
@@ -48,7 +48,7 @@ class PythonSampler : public Sampler {
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param int32_t workerId - not meant to be used
// @param int32_t workerId - not meant to be used
// @return - The error code return
// @return - The error code return
Status
GetNext
Buffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
Status
GetNext
Sample
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
private:
private:
bool
need_to_reset_
;
// Whether Reset() should be called before calling GetNextBuffer()
bool
need_to_reset_
;
// Whether Reset() should be called before calling GetNextBuffer()
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc
浏览文件 @
beefb20c
...
@@ -31,14 +31,14 @@ RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuff
...
@@ -31,14 +31,14 @@ RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuff
reshuffle_each_epoch_
(
reshuffle_each_epoch
),
reshuffle_each_epoch_
(
reshuffle_each_epoch
),
dist
(
nullptr
)
{}
dist
(
nullptr
)
{}
Status
RandomSampler
::
GetNext
Buffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
{
Status
RandomSampler
::
GetNext
Sample
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
{
if
(
next_id_
>
num_samples_
)
{
if
(
next_id_
>
num_samples_
)
{
RETURN_STATUS_UNEXPECTED
(
"RandomSampler Internal Error"
);
RETURN_STATUS_UNEXPECTED
(
"RandomSampler Internal Error"
);
}
else
if
(
next_id_
==
num_samples_
)
{
}
else
if
(
next_id_
==
num_samples_
)
{
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
}
else
{
}
else
{
if
(
HasChildSampler
())
{
if
(
HasChildSampler
())
{
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNext
Buffer
(
&
child_ids_
));
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNext
Sample
(
&
child_ids_
));
}
}
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
next_id_
,
DataBuffer
::
kDeBFlagNone
);
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
next_id_
,
DataBuffer
::
kDeBFlagNone
);
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h
浏览文件 @
beefb20c
...
@@ -41,7 +41,7 @@ class RandomSampler : public Sampler {
...
@@ -41,7 +41,7 @@ class RandomSampler : public Sampler {
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param int32_t workerId - not meant to be used
// @param int32_t workerId - not meant to be used
// @return - The error code return
// @return - The error code return
Status
GetNext
Buffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
Status
GetNext
Sample
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
// meant to be called by base class or python
// meant to be called by base class or python
Status
InitSampler
()
override
;
Status
InitSampler
()
override
;
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
浏览文件 @
beefb20c
...
@@ -33,11 +33,7 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
...
@@ -33,11 +33,7 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
}
}
Sampler
::
Sampler
(
int64_t
num_samples
,
int64_t
samples_per_buffer
)
Sampler
::
Sampler
(
int64_t
num_samples
,
int64_t
samples_per_buffer
)
:
DatasetOp
(
0
),
:
num_rows_
(
0
),
num_samples_
(
num_samples
),
samples_per_buffer_
(
samples_per_buffer
),
col_desc_
(
nullptr
)
{}
num_rows_
(
0
),
num_samples_
(
num_samples
),
samples_per_buffer_
(
samples_per_buffer
),
col_desc_
(
nullptr
)
{}
Status
Sampler
::
HandshakeRandomAccessOp
(
const
RandomAccessOp
*
op
)
{
Status
Sampler
::
HandshakeRandomAccessOp
(
const
RandomAccessOp
*
op
)
{
std
::
shared_ptr
<
Sampler
>
child_sampler
;
std
::
shared_ptr
<
Sampler
>
child_sampler
;
...
@@ -97,7 +93,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
...
@@ -97,7 +93,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
std
::
shared_ptr
<
Tensor
>
sample_ids
;
std
::
shared_ptr
<
Tensor
>
sample_ids
;
// A call to derived class to get sample ids wrapped inside a buffer
// A call to derived class to get sample ids wrapped inside a buffer
RETURN_IF_NOT_OK
(
GetNext
Buffer
(
&
db
));
RETURN_IF_NOT_OK
(
GetNext
Sample
(
&
db
));
// Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch
// Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch
RETURN_IF_NOT_OK
(
db
->
GetTensor
(
&
sample_ids
,
0
,
0
));
RETURN_IF_NOT_OK
(
db
->
GetTensor
(
&
sample_ids
,
0
,
0
));
// check this buffer is not a ctrl buffer
// check this buffer is not a ctrl buffer
...
@@ -114,7 +110,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
...
@@ -114,7 +110,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
}
}
}
}
// perform error checking! Next buffer supposed to be EOE since last one already contains all ids for current epoch
// perform error checking! Next buffer supposed to be EOE since last one already contains all ids for current epoch
RETURN_IF_NOT_OK
(
GetNext
Buffer
(
&
db
));
RETURN_IF_NOT_OK
(
GetNext
Sample
(
&
db
));
CHECK_FAIL_RETURN_UNEXPECTED
(
db
->
eoe
(),
"ERROR Non EOE received"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
db
->
eoe
(),
"ERROR Non EOE received"
);
// Reset Sampler since this is the end of the epoch
// Reset Sampler since this is the end of the epoch
RETURN_IF_NOT_OK
(
Reset
());
RETURN_IF_NOT_OK
(
Reset
());
...
@@ -133,17 +129,7 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
...
@@ -133,17 +129,7 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
return
Status
::
OK
();
return
Status
::
OK
();
}
}
// inline op doesn't have it's own consumer, it's assigned from parent
Status
Sampler
::
AddChild
(
std
::
shared_ptr
<
Sampler
>
child
)
{
int32_t
Sampler
::
num_consumers
()
const
{
if
(
parent_
.
empty
()
||
parent_
[
0
]
==
nullptr
)
{
MS_LOG
(
WARNING
)
<<
"Sampler with no parent. num_consumers is 0."
;
return
0
;
}
else
{
return
parent_
[
0
]
->
num_consumers
();
}
}
Status
Sampler
::
AddChild
(
std
::
shared_ptr
<
DatasetOp
>
child
)
{
if
(
child
==
nullptr
)
{
if
(
child
==
nullptr
)
{
return
Status
::
OK
();
return
Status
::
OK
();
}
}
...
@@ -182,14 +168,5 @@ Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
...
@@ -182,14 +168,5 @@ Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
return
Status
::
OK
();
return
Status
::
OK
();
}
}
// inline op doesn't have it's own producers, it's assigned from child
int32_t
Sampler
::
num_producers
()
const
{
if
(
child_
.
empty
()
||
child_
[
0
]
==
nullptr
)
{
MS_LOG
(
WARNING
)
<<
"Sampler with no child, num_producers is 0."
;
return
0
;
}
else
{
return
child_
[
0
]
->
num_producers
();
}
}
}
// namespace dataset
}
// namespace dataset
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h
浏览文件 @
beefb20c
...
@@ -54,7 +54,7 @@ class RandomAccessOp {
...
@@ -54,7 +54,7 @@ class RandomAccessOp {
int64_t
num_rows_
;
int64_t
num_rows_
;
};
};
class
Sampler
:
public
DatasetOp
{
class
Sampler
{
public:
public:
// Constructor
// Constructor
// @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0
// @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0
...
@@ -70,14 +70,14 @@ class Sampler : public DatasetOp {
...
@@ -70,14 +70,14 @@ class Sampler : public DatasetOp {
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param int32_t workerId - not meant to be used
// @param int32_t workerId - not meant to be used
// @return - The error code return
// @return - The error code return
Status
GetNextBuffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
=
0
;
virtual
Status
GetNextSample
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
=
0
;
// return all ids in one epoch as a numpy array, then call reset
// return all ids in one epoch as a numpy array, then call reset
Status
GetAllIdsThenReset
(
py
::
array
*
data
);
Status
GetAllIdsThenReset
(
py
::
array
*
data
);
// for next epoch of sampleIds
// for next epoch of sampleIds
// @return - The error code return
// @return - The error code return
Status
Reset
()
override
=
0
;
virtual
Status
Reset
()
=
0
;
// first handshake between leaf source op and Sampler. This func will determine the amount of data
// first handshake between leaf source op and Sampler. This func will determine the amount of data
// in the dataset that we can sample from.
// in the dataset that we can sample from.
...
@@ -98,26 +98,10 @@ class Sampler : public DatasetOp {
...
@@ -98,26 +98,10 @@ class Sampler : public DatasetOp {
// @return status error code
// @return status error code
Status
SetNumRowsInDataset
(
int64_t
num_rows
);
Status
SetNumRowsInDataset
(
int64_t
num_rows
);
// Sampler is an inlined op and has no workers. Producers and consumers are computed.
// @return
int32_t
num_workers
()
const
final
{
return
0
;
}
// Identify num consumers (inlined op)
// @return
int32_t
num_consumers
()
const
final
;
// Identify num producers (inlined op)
// @return
int32_t
num_producers
()
const
final
;
// Not meant to be called!
// @return - The error code return
Status
operator
()()
final
{
RETURN_STATUS_UNEXPECTED
(
"Functor not supported in Sampler"
);
}
// Adds a sampler to become our child.
// Adds a sampler to become our child.
// @param std::shared_ptr<DatasetOp> - The sampler to add as a child.
// @param std::shared_ptr<DatasetOp> - The sampler to add as a child.
// @return - The error code returned.
// @return - The error code returned.
Status
AddChild
(
std
::
shared_ptr
<
DatasetOp
>
child
);
Status
AddChild
(
std
::
shared_ptr
<
Sampler
>
child
);
// A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
// A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
// @param std::shared_ptr<Tensor>* sampleIds
// @param std::shared_ptr<Tensor>* sampleIds
...
@@ -125,7 +109,7 @@ class Sampler : public DatasetOp {
...
@@ -125,7 +109,7 @@ class Sampler : public DatasetOp {
// @return - The error code returned.
// @return - The error code returned.
Status
CreateSamplerTensor
(
std
::
shared_ptr
<
Tensor
>
*
sample_ids
,
int64_t
num_elements
);
Status
CreateSamplerTensor
(
std
::
shared_ptr
<
Tensor
>
*
sample_ids
,
int64_t
num_elements
);
v
oid
Print
(
std
::
ostream
&
out
,
bool
show_all
)
const
override
;
v
irtual
void
Print
(
std
::
ostream
&
out
,
bool
show_all
)
const
;
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
Sampler
&
sampler
)
{
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
Sampler
&
sampler
)
{
sampler
.
Print
(
out
,
false
);
sampler
.
Print
(
out
,
false
);
...
@@ -156,6 +140,7 @@ class Sampler : public DatasetOp {
...
@@ -156,6 +140,7 @@ class Sampler : public DatasetOp {
int64_t
samples_per_buffer_
;
int64_t
samples_per_buffer_
;
std
::
unique_ptr
<
ColDescriptor
>
col_desc_
;
std
::
unique_ptr
<
ColDescriptor
>
col_desc_
;
std
::
vector
<
std
::
shared_ptr
<
Sampler
>>
child_
;
// Child nodes
std
::
unique_ptr
<
DataBuffer
>
child_ids_
;
std
::
unique_ptr
<
DataBuffer
>
child_ids_
;
};
};
}
// namespace dataset
}
// namespace dataset
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc
浏览文件 @
beefb20c
...
@@ -23,14 +23,14 @@ namespace dataset {
...
@@ -23,14 +23,14 @@ namespace dataset {
SequentialSampler
::
SequentialSampler
(
int64_t
num_samples
,
int64_t
start_index
,
int64_t
samples_per_buffer
)
SequentialSampler
::
SequentialSampler
(
int64_t
num_samples
,
int64_t
start_index
,
int64_t
samples_per_buffer
)
:
Sampler
(
num_samples
,
samples_per_buffer
),
start_index_
(
start_index
),
current_id_
(
start_index
),
id_count_
(
0
)
{}
:
Sampler
(
num_samples
,
samples_per_buffer
),
start_index_
(
start_index
),
current_id_
(
start_index
),
id_count_
(
0
)
{}
Status
SequentialSampler
::
GetNext
Buffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
{
Status
SequentialSampler
::
GetNext
Sample
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
{
if
(
id_count_
>
num_samples_
)
{
if
(
id_count_
>
num_samples_
)
{
RETURN_STATUS_UNEXPECTED
(
"SequentialSampler Internal Error"
);
RETURN_STATUS_UNEXPECTED
(
"SequentialSampler Internal Error"
);
}
else
if
(
id_count_
==
num_samples_
)
{
}
else
if
(
id_count_
==
num_samples_
)
{
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
}
else
{
}
else
{
if
(
HasChildSampler
())
{
if
(
HasChildSampler
())
{
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNext
Buffer
(
&
child_ids_
));
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNext
Sample
(
&
child_ids_
));
}
}
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
current_id_
,
DataBuffer
::
kDeBFlagNone
);
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
current_id_
,
DataBuffer
::
kDeBFlagNone
);
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h
浏览文件 @
beefb20c
...
@@ -47,7 +47,7 @@ class SequentialSampler : public Sampler {
...
@@ -47,7 +47,7 @@ class SequentialSampler : public Sampler {
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param int32_t workerId - not meant to be used
// @param int32_t workerId - not meant to be used
// @return - The error code return
// @return - The error code return
Status
GetNext
Buffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
Status
GetNext
Sample
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
void
Print
(
std
::
ostream
&
out
,
bool
show_all
)
const
override
;
void
Print
(
std
::
ostream
&
out
,
bool
show_all
)
const
override
;
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc
浏览文件 @
beefb20c
...
@@ -72,13 +72,13 @@ Status SubsetRandomSampler::Reset() {
...
@@ -72,13 +72,13 @@ Status SubsetRandomSampler::Reset() {
}
}
// Get the sample ids.
// Get the sample ids.
Status
SubsetRandomSampler
::
GetNext
Buffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
{
Status
SubsetRandomSampler
::
GetNext
Sample
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
{
// All samples have been drawn
// All samples have been drawn
if
(
sample_id_
==
num_samples_
)
{
if
(
sample_id_
==
num_samples_
)
{
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
buffer_id_
++
,
DataBuffer
::
kDeBFlagEOE
);
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
buffer_id_
++
,
DataBuffer
::
kDeBFlagEOE
);
}
else
{
}
else
{
if
(
HasChildSampler
())
{
if
(
HasChildSampler
())
{
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNext
Buffer
(
&
child_ids_
));
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNext
Sample
(
&
child_ids_
));
}
}
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
buffer_id_
++
,
DataBuffer
::
kDeBFlagNone
);
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
buffer_id_
++
,
DataBuffer
::
kDeBFlagNone
);
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h
浏览文件 @
beefb20c
...
@@ -49,7 +49,7 @@ class SubsetRandomSampler : public Sampler {
...
@@ -49,7 +49,7 @@ class SubsetRandomSampler : public Sampler {
// Get the sample ids.
// Get the sample ids.
// @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
// @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
// @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer.
// @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer.
Status
GetNext
Buffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
Status
GetNext
Sample
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
private:
private:
// A list of indices (already randomized in constructor).
// A list of indices (already randomized in constructor).
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc
浏览文件 @
beefb20c
...
@@ -95,7 +95,7 @@ Status WeightedRandomSampler::Reset() {
...
@@ -95,7 +95,7 @@ Status WeightedRandomSampler::Reset() {
}
}
// Get the sample ids.
// Get the sample ids.
Status
WeightedRandomSampler
::
GetNext
Buffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
{
Status
WeightedRandomSampler
::
GetNext
Sample
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
{
if
(
weights_
.
size
()
>
static_cast
<
size_t
>
(
num_rows_
))
{
if
(
weights_
.
size
()
>
static_cast
<
size_t
>
(
num_rows_
))
{
return
Status
(
StatusCode
::
kUnexpectedError
,
__LINE__
,
__FILE__
,
return
Status
(
StatusCode
::
kUnexpectedError
,
__LINE__
,
__FILE__
,
"number of samples weights is more than num of rows. Might generate id out of bound OR other errors"
);
"number of samples weights is more than num of rows. Might generate id out of bound OR other errors"
);
...
@@ -109,7 +109,7 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
...
@@ -109,7 +109,7 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
buffer_id_
++
,
DataBuffer
::
kDeBFlagEOE
);
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
buffer_id_
++
,
DataBuffer
::
kDeBFlagEOE
);
}
else
{
}
else
{
if
(
HasChildSampler
())
{
if
(
HasChildSampler
())
{
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNext
Buffer
(
&
child_ids_
));
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNext
Sample
(
&
child_ids_
));
}
}
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
buffer_id_
++
,
DataBuffer
::
kDeBFlagNone
);
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
buffer_id_
++
,
DataBuffer
::
kDeBFlagNone
);
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h
浏览文件 @
beefb20c
...
@@ -51,7 +51,7 @@ class WeightedRandomSampler : public Sampler {
...
@@ -51,7 +51,7 @@ class WeightedRandomSampler : public Sampler {
// Get the sample ids.
// Get the sample ids.
// @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
// @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
// @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer.
// @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer.
Status
GetNext
Buffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
Status
GetNext
Sample
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
private:
private:
// A list of weights for each sample.
// A list of weights for each sample.
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
浏览文件 @
beefb20c
...
@@ -123,7 +123,7 @@ Status VOCOp::TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::
...
@@ -123,7 +123,7 @@ Status VOCOp::TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::
Status
VOCOp
::
operator
()()
{
Status
VOCOp
::
operator
()()
{
RETURN_IF_NOT_OK
(
LaunchThreadsAndInitOp
());
RETURN_IF_NOT_OK
(
LaunchThreadsAndInitOp
());
std
::
unique_ptr
<
DataBuffer
>
sampler_buffer
;
std
::
unique_ptr
<
DataBuffer
>
sampler_buffer
;
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Buffer
(
&
sampler_buffer
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Sample
(
&
sampler_buffer
));
while
(
true
)
{
while
(
true
)
{
std
::
vector
<
int64_t
>
keys
;
std
::
vector
<
int64_t
>
keys
;
keys
.
reserve
(
rows_per_buffer_
);
keys
.
reserve
(
rows_per_buffer_
);
...
@@ -134,7 +134,7 @@ Status VOCOp::operator()() {
...
@@ -134,7 +134,7 @@ Status VOCOp::operator()() {
RETURN_STATUS_UNEXPECTED
(
"Sampler Tensor isn't int64"
);
RETURN_STATUS_UNEXPECTED
(
"Sampler Tensor isn't int64"
);
}
}
RETURN_IF_NOT_OK
(
TraverseSampleIds
(
sample_ids
,
&
keys
));
RETURN_IF_NOT_OK
(
TraverseSampleIds
(
sample_ids
,
&
keys
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Buffer
(
&
sampler_buffer
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Sample
(
&
sampler_buffer
));
}
}
if
(
keys
.
empty
()
==
false
)
{
if
(
keys
.
empty
()
==
false
)
{
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
RETURN_IF_NOT_OK
(
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
...
@@ -155,7 +155,7 @@ Status VOCOp::operator()() {
...
@@ -155,7 +155,7 @@ Status VOCOp::operator()() {
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
)));
io_block_queues_
[(
buf_cnt_
++
)
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
)));
RETURN_IF_NOT_OK
(
wp_
.
Wait
());
RETURN_IF_NOT_OK
(
wp_
.
Wait
());
wp_
.
Clear
();
wp_
.
Clear
();
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Buffer
(
&
sampler_buffer
));
RETURN_IF_NOT_OK
(
sampler_
->
GetNext
Sample
(
&
sampler_buffer
));
}
}
}
}
}
}
...
...
tests/ut/cpp/dataset/stand_alone_samplers_test.cc
浏览文件 @
beefb20c
...
@@ -68,7 +68,7 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) {
...
@@ -68,7 +68,7 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) {
for
(
int
i
=
0
;
i
<
6
;
i
++
)
{
for
(
int
i
=
0
;
i
<
6
;
i
++
)
{
std
::
shared_ptr
<
Sampler
>
sampler
=
std
::
make_shared
<
DistributedSampler
>
(
num_samples
,
3
,
i
%
3
,
(
i
<
3
?
false
:
true
));
std
::
shared_ptr
<
Sampler
>
sampler
=
std
::
make_shared
<
DistributedSampler
>
(
num_samples
,
3
,
i
%
3
,
(
i
<
3
?
false
:
true
));
sampler
->
HandshakeRandomAccessOp
(
&
mock
);
sampler
->
HandshakeRandomAccessOp
(
&
mock
);
sampler
->
GetNext
Buffer
(
&
db
);
sampler
->
GetNext
Sample
(
&
db
);
db
->
GetTensor
(
&
tensor
,
0
,
0
);
db
->
GetTensor
(
&
tensor
,
0
,
0
);
MS_LOG
(
DEBUG
)
<<
(
*
tensor
);
MS_LOG
(
DEBUG
)
<<
(
*
tensor
);
if
(
i
<
3
)
{
// This is added due to std::shuffle()
if
(
i
<
3
)
{
// This is added due to std::shuffle()
...
@@ -90,17 +90,17 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) {
...
@@ -90,17 +90,17 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) {
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
shared_ptr
<
Tensor
>
tensor
;
std
::
shared_ptr
<
Tensor
>
tensor
;
sampler
->
HandshakeRandomAccessOp
(
&
mock
);
sampler
->
HandshakeRandomAccessOp
(
&
mock
);
sampler
->
GetNext
Buffer
(
&
db
);
sampler
->
GetNext
Sample
(
&
db
);
db
->
GetTensor
(
&
tensor
,
0
,
0
);
db
->
GetTensor
(
&
tensor
,
0
,
0
);
EXPECT_TRUE
((
*
tensor
)
==
(
*
label1
));
EXPECT_TRUE
((
*
tensor
)
==
(
*
label1
));
sampler
->
GetNext
Buffer
(
&
db
);
sampler
->
GetNext
Sample
(
&
db
);
db
->
GetTensor
(
&
tensor
,
0
,
0
);
db
->
GetTensor
(
&
tensor
,
0
,
0
);
EXPECT_TRUE
((
*
tensor
)
==
(
*
label2
));
EXPECT_TRUE
((
*
tensor
)
==
(
*
label2
));
sampler
->
Reset
();
sampler
->
Reset
();
sampler
->
GetNext
Buffer
(
&
db
);
sampler
->
GetNext
Sample
(
&
db
);
db
->
GetTensor
(
&
tensor
,
0
,
0
);
db
->
GetTensor
(
&
tensor
,
0
,
0
);
EXPECT_TRUE
((
*
tensor
)
==
(
*
label1
));
EXPECT_TRUE
((
*
tensor
)
==
(
*
label1
));
sampler
->
GetNext
Buffer
(
&
db
);
sampler
->
GetNext
Sample
(
&
db
);
db
->
GetTensor
(
&
tensor
,
0
,
0
);
db
->
GetTensor
(
&
tensor
,
0
,
0
);
EXPECT_TRUE
((
*
tensor
)
==
(
*
label2
));
EXPECT_TRUE
((
*
tensor
)
==
(
*
label2
));
}
}
tests/ut/cpp/dataset/subset_random_sampler_test.cc
浏览文件 @
beefb20c
...
@@ -49,7 +49,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
...
@@ -49,7 +49,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
TensorRow
row
;
std
::
vector
<
int64_t
>
out
;
std
::
vector
<
int64_t
>
out
;
ASSERT_EQ
(
sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
db
->
PopRow
(
&
row
);
db
->
PopRow
(
&
row
);
for
(
const
auto
&
t
:
row
)
{
for
(
const
auto
&
t
:
row
)
{
for
(
auto
it
=
t
->
begin
<
int64_t
>
();
it
!=
t
->
end
<
int64_t
>
();
it
++
)
{
for
(
auto
it
=
t
->
begin
<
int64_t
>
();
it
!=
t
->
end
<
int64_t
>
();
it
++
)
{
...
@@ -61,7 +61,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
...
@@ -61,7 +61,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
ASSERT_NE
(
in_set
.
find
(
out
[
i
]),
in_set
.
end
());
ASSERT_NE
(
in_set
.
find
(
out
[
i
]),
in_set
.
end
());
}
}
ASSERT_EQ
(
sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
db
->
eoe
(),
true
);
ASSERT_EQ
(
db
->
eoe
(),
true
);
}
}
...
@@ -79,7 +79,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
...
@@ -79,7 +79,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
TensorRow
row
;
TensorRow
row
;
std
::
vector
<
int64_t
>
out
;
std
::
vector
<
int64_t
>
out
;
ASSERT_EQ
(
sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
int
epoch
=
0
;
int
epoch
=
0
;
while
(
!
db
->
eoe
())
{
while
(
!
db
->
eoe
())
{
epoch
++
;
epoch
++
;
...
@@ -91,7 +91,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
...
@@ -91,7 +91,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
}
}
db
.
reset
();
db
.
reset
();
ASSERT_EQ
(
sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
}
}
ASSERT_EQ
(
epoch
,
(
total_samples
+
samples_per_buffer
-
1
)
/
samples_per_buffer
);
ASSERT_EQ
(
epoch
,
(
total_samples
+
samples_per_buffer
-
1
)
/
samples_per_buffer
);
...
@@ -111,7 +111,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
...
@@ -111,7 +111,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
TensorRow
row
;
TensorRow
row
;
std
::
vector
<
int64_t
>
out
;
std
::
vector
<
int64_t
>
out
;
ASSERT_EQ
(
sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
db
->
PopRow
(
&
row
);
db
->
PopRow
(
&
row
);
for
(
const
auto
&
t
:
row
)
{
for
(
const
auto
&
t
:
row
)
{
for
(
auto
it
=
t
->
begin
<
int64_t
>
();
it
!=
t
->
end
<
int64_t
>
();
it
++
)
{
for
(
auto
it
=
t
->
begin
<
int64_t
>
();
it
!=
t
->
end
<
int64_t
>
();
it
++
)
{
...
@@ -125,7 +125,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
...
@@ -125,7 +125,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
sampler
.
Reset
();
sampler
.
Reset
();
ASSERT_EQ
(
sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
db
->
eoe
(),
false
);
ASSERT_EQ
(
db
->
eoe
(),
false
);
db
->
PopRow
(
&
row
);
db
->
PopRow
(
&
row
);
out
.
clear
();
out
.
clear
();
...
@@ -139,6 +139,6 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
...
@@ -139,6 +139,6 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
ASSERT_NE
(
in_set
.
find
(
out
[
i
]),
in_set
.
end
());
ASSERT_NE
(
in_set
.
find
(
out
[
i
]),
in_set
.
end
());
}
}
ASSERT_EQ
(
sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
db
->
eoe
(),
true
);
ASSERT_EQ
(
db
->
eoe
(),
true
);
}
}
tests/ut/cpp/dataset/weighted_random_sampler_test.cc
浏览文件 @
beefb20c
...
@@ -58,7 +58,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
...
@@ -58,7 +58,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
TensorRow
row
;
std
::
vector
<
uint64_t
>
out
;
std
::
vector
<
uint64_t
>
out
;
ASSERT_EQ
(
m_sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
m_sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
db
->
PopRow
(
&
row
);
db
->
PopRow
(
&
row
);
for
(
const
auto
&
t
:
row
)
{
for
(
const
auto
&
t
:
row
)
{
for
(
auto
it
=
t
->
begin
<
uint64_t
>
();
it
!=
t
->
end
<
uint64_t
>
();
it
++
)
{
for
(
auto
it
=
t
->
begin
<
uint64_t
>
();
it
!=
t
->
end
<
uint64_t
>
();
it
++
)
{
...
@@ -69,7 +69,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
...
@@ -69,7 +69,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
ASSERT_EQ
(
num_samples
,
out
.
size
());
ASSERT_EQ
(
num_samples
,
out
.
size
());
ASSERT_EQ
(
m_sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
m_sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
db
->
eoe
(),
true
);
ASSERT_EQ
(
db
->
eoe
(),
true
);
}
}
...
@@ -88,7 +88,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
...
@@ -88,7 +88,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
TensorRow
row
;
std
::
vector
<
uint64_t
>
out
;
std
::
vector
<
uint64_t
>
out
;
ASSERT_EQ
(
m_sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
m_sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
db
->
PopRow
(
&
row
);
db
->
PopRow
(
&
row
);
for
(
const
auto
&
t
:
row
)
{
for
(
const
auto
&
t
:
row
)
{
for
(
auto
it
=
t
->
begin
<
uint64_t
>
();
it
!=
t
->
end
<
uint64_t
>
();
it
++
)
{
for
(
auto
it
=
t
->
begin
<
uint64_t
>
();
it
!=
t
->
end
<
uint64_t
>
();
it
++
)
{
...
@@ -105,7 +105,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
...
@@ -105,7 +105,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
}
}
}
}
ASSERT_EQ
(
m_sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
m_sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
db
->
eoe
(),
true
);
ASSERT_EQ
(
db
->
eoe
(),
true
);
}
}
...
@@ -124,7 +124,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
...
@@ -124,7 +124,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
TensorRow
row
;
std
::
vector
<
uint64_t
>
out
;
std
::
vector
<
uint64_t
>
out
;
ASSERT_EQ
(
m_sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
m_sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
int
epoch
=
0
;
int
epoch
=
0
;
while
(
!
db
->
eoe
())
{
while
(
!
db
->
eoe
())
{
epoch
++
;
epoch
++
;
...
@@ -135,7 +135,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
...
@@ -135,7 +135,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
}
}
}
}
db
.
reset
();
db
.
reset
();
ASSERT_EQ
(
m_sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
m_sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
}
}
ASSERT_EQ
(
epoch
,
(
num_samples
+
samples_per_buffer
-
1
)
/
samples_per_buffer
);
ASSERT_EQ
(
epoch
,
(
num_samples
+
samples_per_buffer
-
1
)
/
samples_per_buffer
);
...
@@ -160,7 +160,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
...
@@ -160,7 +160,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
TensorRow
row
;
std
::
vector
<
uint64_t
>
out
;
std
::
vector
<
uint64_t
>
out
;
ASSERT_EQ
(
m_sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
m_sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
int
epoch
=
0
;
int
epoch
=
0
;
while
(
!
db
->
eoe
())
{
while
(
!
db
->
eoe
())
{
epoch
++
;
epoch
++
;
...
@@ -172,7 +172,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
...
@@ -172,7 +172,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
}
}
}
}
db
.
reset
();
db
.
reset
();
ASSERT_EQ
(
m_sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
m_sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
}
}
// Without replacement, each sample only drawn once.
// Without replacement, each sample only drawn once.
...
@@ -201,7 +201,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
...
@@ -201,7 +201,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
TensorRow
row
;
std
::
vector
<
uint64_t
>
out
;
std
::
vector
<
uint64_t
>
out
;
ASSERT_EQ
(
m_sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
m_sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
db
->
PopRow
(
&
row
);
db
->
PopRow
(
&
row
);
for
(
const
auto
&
t
:
row
)
{
for
(
const
auto
&
t
:
row
)
{
for
(
auto
it
=
t
->
begin
<
uint64_t
>
();
it
!=
t
->
end
<
uint64_t
>
();
it
++
)
{
for
(
auto
it
=
t
->
begin
<
uint64_t
>
();
it
!=
t
->
end
<
uint64_t
>
();
it
++
)
{
...
@@ -211,13 +211,13 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
...
@@ -211,13 +211,13 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
}
}
ASSERT_EQ
(
num_samples
,
out
.
size
());
ASSERT_EQ
(
num_samples
,
out
.
size
());
ASSERT_EQ
(
m_sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
m_sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
db
->
eoe
(),
true
);
ASSERT_EQ
(
db
->
eoe
(),
true
);
m_sampler
.
Reset
();
m_sampler
.
Reset
();
out
.
clear
();
out
.
clear
();
ASSERT_EQ
(
m_sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
m_sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
db
->
PopRow
(
&
row
);
db
->
PopRow
(
&
row
);
for
(
const
auto
&
t
:
row
)
{
for
(
const
auto
&
t
:
row
)
{
for
(
auto
it
=
t
->
begin
<
uint64_t
>
();
it
!=
t
->
end
<
uint64_t
>
();
it
++
)
{
for
(
auto
it
=
t
->
begin
<
uint64_t
>
();
it
!=
t
->
end
<
uint64_t
>
();
it
++
)
{
...
@@ -227,7 +227,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
...
@@ -227,7 +227,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
}
}
ASSERT_EQ
(
num_samples
,
out
.
size
());
ASSERT_EQ
(
num_samples
,
out
.
size
());
ASSERT_EQ
(
m_sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
m_sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
db
->
eoe
(),
true
);
ASSERT_EQ
(
db
->
eoe
(),
true
);
}
}
...
@@ -246,7 +246,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
...
@@ -246,7 +246,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
TensorRow
row
;
std
::
vector
<
uint64_t
>
out
;
std
::
vector
<
uint64_t
>
out
;
ASSERT_EQ
(
m_sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
m_sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
db
->
PopRow
(
&
row
);
db
->
PopRow
(
&
row
);
for
(
const
auto
&
t
:
row
)
{
for
(
const
auto
&
t
:
row
)
{
for
(
auto
it
=
t
->
begin
<
uint64_t
>
();
it
!=
t
->
end
<
uint64_t
>
();
it
++
)
{
for
(
auto
it
=
t
->
begin
<
uint64_t
>
();
it
!=
t
->
end
<
uint64_t
>
();
it
++
)
{
...
@@ -256,7 +256,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
...
@@ -256,7 +256,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
}
}
ASSERT_EQ
(
num_samples
,
out
.
size
());
ASSERT_EQ
(
num_samples
,
out
.
size
());
ASSERT_EQ
(
m_sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
m_sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
db
->
eoe
(),
true
);
ASSERT_EQ
(
db
->
eoe
(),
true
);
m_sampler
.
Reset
();
m_sampler
.
Reset
();
...
@@ -265,7 +265,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
...
@@ -265,7 +265,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
freq
.
resize
(
total_samples
,
0
);
freq
.
resize
(
total_samples
,
0
);
MS_LOG
(
INFO
)
<<
"Resetting sampler"
;
MS_LOG
(
INFO
)
<<
"Resetting sampler"
;
ASSERT_EQ
(
m_sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
m_sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
db
->
PopRow
(
&
row
);
db
->
PopRow
(
&
row
);
for
(
const
auto
&
t
:
row
)
{
for
(
const
auto
&
t
:
row
)
{
for
(
auto
it
=
t
->
begin
<
uint64_t
>
();
it
!=
t
->
end
<
uint64_t
>
();
it
++
)
{
for
(
auto
it
=
t
->
begin
<
uint64_t
>
();
it
!=
t
->
end
<
uint64_t
>
();
it
++
)
{
...
@@ -282,6 +282,6 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
...
@@ -282,6 +282,6 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
}
}
}
}
ASSERT_EQ
(
m_sampler
.
GetNext
Buffer
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
m_sampler
.
GetNext
Sample
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
db
->
eoe
(),
true
);
ASSERT_EQ
(
db
->
eoe
(),
true
);
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录