Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
6945eb28
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6945eb28
编写于
8月 01, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 01, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3622 Added option to distributed sampler
Merge pull request !3622 from EricZ/distributed_sampler_fix
上级
2449e4e7
8c018da4
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
175 addition
and
23 deletion
+175
-23
mindspore/ccsrc/minddata/dataset/api/samplers.cc
mindspore/ccsrc/minddata/dataset/api/samplers.cc
+11
-5
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
...t/engine/datasetops/source/sampler/distributed_sampler.cc
+12
-3
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h
...et/engine/datasetops/source/sampler/distributed_sampler.h
+19
-12
mindspore/ccsrc/minddata/dataset/include/samplers.h
mindspore/ccsrc/minddata/dataset/include/samplers.h
+7
-2
tests/ut/cpp/dataset/CMakeLists.txt
tests/ut/cpp/dataset/CMakeLists.txt
+3
-1
tests/ut/cpp/dataset/distributed_sampler_test.cc
tests/ut/cpp/dataset/distributed_sampler_test.cc
+123
-0
未找到文件。
mindspore/ccsrc/minddata/dataset/api/samplers.cc
浏览文件 @
6945eb28
...
...
@@ -31,8 +31,8 @@ SamplerObj::SamplerObj() {}
/// Function to create a Distributed Sampler.
std
::
shared_ptr
<
DistributedSamplerObj
>
DistributedSampler
(
int64_t
num_shards
,
int64_t
shard_id
,
bool
shuffle
,
int64_t
num_samples
,
uint32_t
seed
)
{
auto
sampler
=
std
::
make_shared
<
DistributedSamplerObj
>
(
num_shards
,
shard_id
,
shuffle
,
num_samples
,
seed
);
int64_t
num_samples
,
uint32_t
seed
,
bool
even_dist
)
{
auto
sampler
=
std
::
make_shared
<
DistributedSamplerObj
>
(
num_shards
,
shard_id
,
shuffle
,
num_samples
,
seed
,
even_dist
);
// Input validation
if
(
!
sampler
->
ValidateParams
())
{
return
nullptr
;
...
...
@@ -95,8 +95,13 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vecto
// DistributedSampler
DistributedSamplerObj
::
DistributedSamplerObj
(
int64_t
num_shards
,
int64_t
shard_id
,
bool
shuffle
,
int64_t
num_samples
,
uint32_t
seed
)
:
num_shards_
(
num_shards
),
shard_id_
(
shard_id
),
shuffle_
(
shuffle
),
num_samples_
(
num_samples
),
seed_
(
seed
)
{}
uint32_t
seed
,
bool
even_dist
)
:
num_shards_
(
num_shards
),
shard_id_
(
shard_id
),
shuffle_
(
shuffle
),
num_samples_
(
num_samples
),
seed_
(
seed
),
even_dist_
(
even_dist
)
{}
bool
DistributedSamplerObj
::
ValidateParams
()
{
if
(
num_shards_
<=
0
)
{
...
...
@@ -118,7 +123,8 @@ bool DistributedSamplerObj::ValidateParams() {
}
std
::
shared_ptr
<
Sampler
>
DistributedSamplerObj
::
Build
()
{
return
std
::
make_shared
<
dataset
::
DistributedSampler
>
(
num_samples_
,
num_shards_
,
shard_id_
,
shuffle_
,
seed_
);
return
std
::
make_shared
<
dataset
::
DistributedSampler
>
(
num_samples_
,
num_shards_
,
shard_id_
,
shuffle_
,
seed_
,
even_dist_
);
}
// PKSampler
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
浏览文件 @
6945eb28
...
...
@@ -24,13 +24,14 @@
namespace
mindspore
{
namespace
dataset
{
DistributedSampler
::
DistributedSampler
(
int64_t
num_samples
,
int64_t
num_dev
,
int64_t
dev_id
,
bool
shuffle
,
uint32_t
seed
)
uint32_t
seed
,
bool
even_dist
)
:
Sampler
(
num_samples
,
std
::
numeric_limits
<
int64_t
>::
max
()),
cnt_
(
0
),
seed_
(
seed
==
std
::
numeric_limits
<
uint32_t
>::
max
()
?
GetSeed
()
:
seed
),
device_id_
(
dev_id
),
num_devices_
(
num_dev
),
shuffle_
(
shuffle
)
{}
shuffle_
(
shuffle
),
even_dist_
(
even_dist
)
{}
Status
DistributedSampler
::
InitSampler
()
{
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
...
...
@@ -43,7 +44,15 @@ Status DistributedSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED
(
device_id_
<
num_devices_
&&
device_id_
>=
0
&&
num_rows_
>
0
&&
num_samples_
>
0
,
"fail to init DistributedSampler"
);
rnd_
.
seed
(
seed_
++
);
samples_per_buffer_
=
(
num_rows_
+
num_devices_
-
1
)
/
num_devices_
;
// equals to ceil(num_rows/num_devices)
if
(
even_dist_
)
{
samples_per_buffer_
=
(
num_rows_
+
num_devices_
-
1
)
/
num_devices_
;
// equals to ceil(num_rows/num_devices)
}
else
{
int64_t
mod
=
num_rows_
%
num_devices_
;
samples_per_buffer_
=
num_rows_
/
num_devices_
;
if
(
mod
>
device_id_
)
{
samples_per_buffer_
++
;
}
}
samples_per_buffer_
=
num_samples_
<
samples_per_buffer_
?
num_samples_
:
samples_per_buffer_
;
if
(
shuffle_
==
true
)
{
shuffle_vec_
.
reserve
(
num_rows_
);
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h
浏览文件 @
6945eb28
...
...
@@ -27,26 +27,32 @@ namespace mindspore {
namespace
dataset
{
class
DistributedSampler
:
public
Sampler
{
public:
// @param num_samples
// @param int64_t num_dev
// @param int64_t dev_id
// @param bool shuffle
/// \brief Constructor
/// \param[in] num_samples The total number of rows in the dataset
/// \param[in] num_dev Total number of shards for the distributed sampler
/// \param[in] dev_id Device id of the shard
/// \param[in] shuffle Option to shuffle
/// \param seed Seed parameter to shuffle, default to max unsigned int (different seed in sampler will
/// result in different samples being picked
/// \param even_dist The option to indicate whether or not each shard returns the same number of rows.
/// This option is not exposed in the python API. Current behavior is that the remainder will always
/// be handled by the first n shards, n being the corresponding device id.
DistributedSampler
(
int64_t
num_samples
,
int64_t
num_dev
,
int64_t
dev_id
,
bool
shuffle
,
uint32_t
seed
=
std
::
numeric_limits
<
uint32_t
>::
max
());
uint32_t
seed
=
std
::
numeric_limits
<
uint32_t
>::
max
()
,
bool
even_dist
=
true
);
// default destructor
//
/ \brief
default destructor
~
DistributedSampler
()
=
default
;
//
@
param std::unique_ptr<DataBuffer> * pBuffer
//
@
param int32_t workerId
//
@return - The error code return
//
/ \
param std::unique_ptr<DataBuffer> * pBuffer
//
/ \
param int32_t workerId
//
/ \return Status code
Status
GetNextSample
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
// Init sampler, called by base class or python
//
/
Init sampler, called by base class or python
Status
InitSampler
()
override
;
// for next epoch of sampleIds
//
@return - The error code return
//
/ \brief
for next epoch of sampleIds
//
/ \return Status code
Status
ResetSampler
()
override
;
void
Print
(
std
::
ostream
&
out
,
bool
show_all
)
const
override
;
...
...
@@ -59,6 +65,7 @@ class DistributedSampler : public Sampler {
bool
shuffle_
;
std
::
mt19937
rnd_
;
std
::
vector
<
int64_t
>
shuffle_vec_
;
bool
even_dist_
;
};
}
// namespace dataset
}
// namespace mindspore
...
...
mindspore/ccsrc/minddata/dataset/include/samplers.h
浏览文件 @
6945eb28
...
...
@@ -52,9 +52,12 @@ class WeightedRandomSamplerObj;
/// \param[in] shuffle - If true, the indices are shuffled.
/// \param[in] num_samples - The number of samples to draw (default to all elements).
/// \param[in] seed - The seed in use when shuffle is true.
/// \param[in] even_dist - If true, each shard would return the same number of rows (default to true).
/// If false the total rows returned by all the shards would not have overlap.
/// \return Shared pointer to the current Sampler.
std
::
shared_ptr
<
DistributedSamplerObj
>
DistributedSampler
(
int64_t
num_shards
,
int64_t
shard_id
,
bool
shuffle
=
true
,
int64_t
num_samples
=
0
,
uint32_t
seed
=
1
);
int64_t
num_samples
=
0
,
uint32_t
seed
=
1
,
bool
even_dist
=
true
);
/// Function to create a PK Sampler.
/// \notes Samples K elements for each P class in the dataset.
...
...
@@ -100,7 +103,8 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vecto
/* ####################################### Derived Sampler classes ################################# */
class
DistributedSamplerObj
:
public
SamplerObj
{
public:
DistributedSamplerObj
(
int64_t
num_shards
,
int64_t
shard_id
,
bool
shuffle
,
int64_t
num_samples
,
uint32_t
seed
);
DistributedSamplerObj
(
int64_t
num_shards
,
int64_t
shard_id
,
bool
shuffle
,
int64_t
num_samples
,
uint32_t
seed
,
bool
even_dist
);
~
DistributedSamplerObj
()
=
default
;
...
...
@@ -114,6 +118,7 @@ class DistributedSamplerObj : public SamplerObj {
bool
shuffle_
;
int64_t
num_samples_
;
uint32_t
seed_
;
bool
even_dist_
;
};
class
PKSamplerObj
:
public
SamplerObj
{
...
...
tests/ut/cpp/dataset/CMakeLists.txt
浏览文件 @
6945eb28
...
...
@@ -92,7 +92,9 @@ SET(DE_UT_SRCS
tensor_op_fusion_pass_test.cc
sliding_window_op_test.cc
epoch_ctrl_op_test.cc
swap_red_blue_test.cc
sentence_piece_vocab_op_test.cc
swap_red_blue_test.cc
distributed_sampler_test.cc
)
if
(
ENABLE_PYTHON
)
...
...
tests/ut/cpp/dataset/distributed_sampler_test.cc
0 → 100644
浏览文件 @
6945eb28
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "gtest/gtest.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "utils/log_adapter.h"
#include <vector>
#include <unordered_set>
using
namespace
mindspore
::
dataset
;
using
mindspore
::
MsLogLevel
::
INFO
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
LogStream
;
class
MindDataTestDistributedSampler
:
public
UT
::
Common
{
public:
class
DummyRandomAccessOp
:
public
RandomAccessOp
{
public:
DummyRandomAccessOp
(
uint64_t
num_rows
)
{
// row count is in base class as protected member
// GetNumRowsInDataset does not need an override, the default from base class is fine.
num_rows_
=
num_rows
;
}
};
};
TEST_F
(
MindDataTestDistributedSampler
,
TestTwoShardsOne
)
{
// num samples to draw.
uint64_t
num_samples
=
7
;
// create sampler with replacement = true
DistributedSampler
m_sampler
(
num_samples
,
2
,
0
,
false
,
0
,
false
);
DummyRandomAccessOp
dummyRandomAccessOp
(
num_samples
);
m_sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessOp
);
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
std
::
vector
<
uint64_t
>
out
;
ASSERT_EQ
(
m_sampler
.
GetNextSample
(
&
db
),
Status
::
OK
());
db
->
PopRow
(
&
row
);
for
(
const
auto
&
t
:
row
)
{
for
(
auto
it
=
t
->
begin
<
uint64_t
>
();
it
!=
t
->
end
<
uint64_t
>
();
it
++
)
{
out
.
push_back
(
*
it
);
}
}
ASSERT_EQ
(
4
,
out
.
size
());
ASSERT_EQ
(
m_sampler
.
GetNextSample
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
db
->
eoe
(),
true
);
}
TEST_F
(
MindDataTestDistributedSampler
,
TestTwoShardsTwo
)
{
// num samples to draw.
uint64_t
num_samples
=
7
;
// create sampler with replacement = true
DistributedSampler
m_sampler
(
num_samples
,
2
,
1
,
false
,
0
,
false
);
DummyRandomAccessOp
dummyRandomAccessOp
(
num_samples
);
m_sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessOp
);
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
std
::
vector
<
uint64_t
>
out
;
ASSERT_EQ
(
m_sampler
.
GetNextSample
(
&
db
),
Status
::
OK
());
db
->
PopRow
(
&
row
);
for
(
const
auto
&
t
:
row
)
{
for
(
auto
it
=
t
->
begin
<
uint64_t
>
();
it
!=
t
->
end
<
uint64_t
>
();
it
++
)
{
out
.
push_back
(
*
it
);
}
}
ASSERT_EQ
(
3
,
out
.
size
());
ASSERT_EQ
(
m_sampler
.
GetNextSample
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
db
->
eoe
(),
true
);
}
TEST_F
(
MindDataTestDistributedSampler
,
TestThreeShards
)
{
// num samples to draw.
uint64_t
num_samples
=
2
;
// create sampler with replacement = true
DistributedSampler
m_sampler
(
num_samples
,
3
,
2
,
false
,
0
,
false
);
DummyRandomAccessOp
dummyRandomAccessOp
(
num_samples
);
m_sampler
.
HandshakeRandomAccessOp
(
&
dummyRandomAccessOp
);
std
::
unique_ptr
<
DataBuffer
>
db
;
TensorRow
row
;
std
::
vector
<
uint64_t
>
out
;
ASSERT_EQ
(
m_sampler
.
GetNextSample
(
&
db
),
Status
::
OK
());
db
->
PopRow
(
&
row
);
for
(
const
auto
&
t
:
row
)
{
for
(
auto
it
=
t
->
begin
<
uint64_t
>
();
it
!=
t
->
end
<
uint64_t
>
();
it
++
)
{
out
.
push_back
(
*
it
);
}
}
ASSERT_EQ
(
0
,
out
.
size
());
ASSERT_EQ
(
m_sampler
.
GetNextSample
(
&
db
),
Status
::
OK
());
ASSERT_EQ
(
db
->
eoe
(),
true
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录