Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
7aac5080
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7aac5080
编写于
4年前
作者:
M
mindspore-ci-bot
提交者:
Gitee
4年前
浏览文件
操作
浏览文件
下载
差异文件
!514 [MD] add Pk Sampler in minddataset
Merge pull request !514 from liyong126/mindrecord_pk_sampler_lee
上级
5d467874
f1542a90
master
r0.2
r0.3
r0.5
r0.6
r0.7
v0.7.0-beta
v0.6.0-beta
v0.5.0-beta
v0.3.1-alpha
v0.3.0-alpha
v0.2.0-alpha
无相关合并请求
变更
23
隐藏空白更改
内联
并排
Showing
23 changed file
with
540 addition
and
81 deletion
+540
-81
mindspore/ccsrc/dataset/api/python_bindings.cc
mindspore/ccsrc/dataset/api/python_bindings.cc
+18
-2
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
...e/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
+3
-2
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h
...re/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h
+2
-1
mindspore/ccsrc/mindrecord/include/common/shard_utils.h
mindspore/ccsrc/mindrecord/include/common/shard_utils.h
+2
-0
mindspore/ccsrc/mindrecord/include/shard_category.h
mindspore/ccsrc/mindrecord/include/shard_category.h
+22
-2
mindspore/ccsrc/mindrecord/include/shard_operator.h
mindspore/ccsrc/mindrecord/include/shard_operator.h
+2
-0
mindspore/ccsrc/mindrecord/include/shard_pk_sample.h
mindspore/ccsrc/mindrecord/include/shard_pk_sample.h
+49
-0
mindspore/ccsrc/mindrecord/include/shard_reader.h
mindspore/ccsrc/mindrecord/include/shard_reader.h
+13
-3
mindspore/ccsrc/mindrecord/include/shard_sample.h
mindspore/ccsrc/mindrecord/include/shard_sample.h
+3
-0
mindspore/ccsrc/mindrecord/include/shard_shuffle.h
mindspore/ccsrc/mindrecord/include/shard_shuffle.h
+2
-1
mindspore/ccsrc/mindrecord/include/shard_task.h
mindspore/ccsrc/mindrecord/include/shard_task.h
+3
-1
mindspore/ccsrc/mindrecord/io/shard_reader.cc
mindspore/ccsrc/mindrecord/io/shard_reader.cc
+151
-24
mindspore/ccsrc/mindrecord/meta/shard_category.cc
mindspore/ccsrc/mindrecord/meta/shard_category.cc
+22
-3
mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc
mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc
+46
-0
mindspore/ccsrc/mindrecord/meta/shard_sample.cc
mindspore/ccsrc/mindrecord/meta/shard_sample.cc
+18
-0
mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc
mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc
+19
-11
mindspore/ccsrc/mindrecord/meta/shard_task.cc
mindspore/ccsrc/mindrecord/meta/shard_task.cc
+28
-12
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+7
-4
mindspore/dataset/engine/samplers.py
mindspore/dataset/engine/samplers.py
+2
-0
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
+52
-0
tests/ut/data/mindrecord/testImageNetData/annotation_sampler.txt
...t/data/mindrecord/testImageNetData/annotation_sampler.txt
+10
-0
tests/ut/python/dataset/test_minddataset_sampler.py
tests/ut/python/dataset/test_minddataset_sampler.py
+65
-14
tests/ut/python/dataset/test_serdes_dataset.py
tests/ut/python/dataset/test_serdes_dataset.py
+1
-1
未找到文件。
mindspore/ccsrc/dataset/api/python_bindings.cc
浏览文件 @
7aac5080
...
...
@@ -60,6 +60,7 @@
#include "dataset/kernels/data/to_float16_op.h"
#include "dataset/util/random.h"
#include "mindrecord/include/shard_operator.h"
#include "mindrecord/include/shard_pk_sample.h"
#include "mindrecord/include/shard_sample.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
...
...
@@ -152,9 +153,14 @@ void bindDatasetOps(py::module *m) {
});
(
void
)
py
::
class_
<
MindRecordOp
,
DatasetOp
,
std
::
shared_ptr
<
MindRecordOp
>>
(
*
m
,
"MindRecordOp"
)
.
def_static
(
"get_num_rows"
,
[](
const
std
::
string
&
path
)
{
.
def_static
(
"get_num_rows"
,
[](
const
std
::
string
&
path
,
const
py
::
object
&
sampler
)
{
int64_t
count
=
0
;
THROW_IF_ERROR
(
MindRecordOp
::
CountTotalRows
(
path
,
&
count
));
std
::
shared_ptr
<
mindrecord
::
ShardOperator
>
op
;
if
(
py
::
hasattr
(
sampler
,
"_create_for_minddataset"
))
{
auto
create
=
sampler
.
attr
(
"_create_for_minddataset"
);
op
=
create
().
cast
<
std
::
shared_ptr
<
mindrecord
::
ShardOperator
>>
();
}
THROW_IF_ERROR
(
MindRecordOp
::
CountTotalRows
(
path
,
op
,
&
count
));
return
count
;
});
...
...
@@ -435,6 +441,16 @@ void bindSamplerOps(py::module *m) {
(
void
)
py
::
class_
<
mindrecord
::
ShardSample
,
mindrecord
::
ShardOperator
,
std
::
shared_ptr
<
mindrecord
::
ShardSample
>>
(
*
m
,
"MindrecordSubsetRandomSampler"
)
.
def
(
py
::
init
<
std
::
vector
<
int64_t
>
,
uint32_t
>
(),
py
::
arg
(
"indices"
),
py
::
arg
(
"seed"
)
=
GetSeed
());
(
void
)
py
::
class_
<
mindrecord
::
ShardPkSample
,
mindrecord
::
ShardOperator
,
std
::
shared_ptr
<
mindrecord
::
ShardPkSample
>>
(
*
m
,
"MindrecordPkSampler"
)
.
def
(
py
::
init
([](
int64_t
kVal
,
bool
shuffle
)
{
if
(
shuffle
==
true
)
{
return
std
::
make_shared
<
mindrecord
::
ShardPkSample
>
(
"label"
,
kVal
,
std
::
numeric_limits
<
int64_t
>::
max
(),
GetSeed
());
}
else
{
return
std
::
make_shared
<
mindrecord
::
ShardPkSample
>
(
"label"
,
kVal
);
}
}));
(
void
)
py
::
class_
<
WeightedRandomSampler
,
Sampler
,
std
::
shared_ptr
<
WeightedRandomSampler
>>
(
*
m
,
"WeightedRandomSampler"
)
.
def
(
py
::
init
<
std
::
vector
<
double
>
,
int64_t
,
bool
>
(),
py
::
arg
(
"weights"
),
py
::
arg
(
"numSamples"
),
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
浏览文件 @
7aac5080
...
...
@@ -655,9 +655,10 @@ Status MindRecordOp::LaunchThreadAndInitOp() {
return
Status
::
OK
();
}
Status
MindRecordOp
::
CountTotalRows
(
const
std
::
string
dataset_path
,
int64_t
*
count
)
{
Status
MindRecordOp
::
CountTotalRows
(
const
std
::
string
dataset_path
,
const
std
::
shared_ptr
<
ShardOperator
>
&
op
,
int64_t
*
count
)
{
std
::
unique_ptr
<
ShardReader
>
shard_reader
=
std
::
make_unique
<
ShardReader
>
();
MSRStatus
rc
=
shard_reader
->
CountTotalRows
(
dataset_path
,
count
);
MSRStatus
rc
=
shard_reader
->
CountTotalRows
(
dataset_path
,
op
,
count
);
if
(
rc
==
MSRStatus
::
FAILED
)
{
RETURN_STATUS_UNEXPECTED
(
"MindRecordOp count total rows failed."
);
}
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h
浏览文件 @
7aac5080
...
...
@@ -171,7 +171,8 @@ class MindRecordOp : public ParallelOp {
int32_t
num_rows
()
const
{
return
num_rows_
;
}
// Getter method
static
Status
CountTotalRows
(
const
std
::
string
dataset_path
,
int64_t
*
count
);
static
Status
CountTotalRows
(
const
std
::
string
dataset_path
,
const
std
::
shared_ptr
<
ShardOperator
>
&
op
,
int64_t
*
count
);
// Getter method
int32_t
rows_per_buffer
()
const
{
return
rows_per_buffer_
;
}
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/mindrecord/include/common/shard_utils.h
浏览文件 @
7aac5080
...
...
@@ -72,6 +72,8 @@ enum ShardType {
enum
SamplerType
{
kCustomTopNSampler
,
kCustomTopPercentSampler
,
kSubsetRandomSampler
,
kPKSampler
};
enum
ShuffleType
{
kShuffleCategory
,
kShuffleSample
};
const
double
kEpsilon
=
1e-7
;
const
int
kThreadNumber
=
14
;
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/mindrecord/include/shard_category.h
浏览文件 @
7aac5080
...
...
@@ -17,6 +17,8 @@
#ifndef MINDRECORD_INCLUDE_SHARD_CATEGORY_H_
#define MINDRECORD_INCLUDE_SHARD_CATEGORY_H_
#include <algorithm>
#include <limits>
#include <string>
#include <utility>
#include <vector>
...
...
@@ -26,16 +28,34 @@ namespace mindspore {
namespace
mindrecord
{
class
ShardCategory
:
public
ShardOperator
{
public:
explicit
ShardCategory
(
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>
&
categories
);
explicit
ShardCategory
(
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>
&
categories
,
int64_t
num_elements
=
std
::
numeric_limits
<
int64_t
>::
max
(),
bool
replacement
=
false
);
ShardCategory
(
const
std
::
string
&
category_field
,
int64_t
num_elements
,
int64_t
num_categories
=
std
::
numeric_limits
<
int64_t
>::
max
(),
bool
replacement
=
false
);
~
ShardCategory
()
override
{};
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>
&
get_categories
()
const
;
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>
&
get_categories
()
const
{
return
categories_
;
}
const
std
::
string
GetCategoryField
()
const
{
return
category_field_
;
}
int64_t
GetNumElements
()
const
{
return
num_elements_
;
}
int64_t
GetNumCategories
()
const
{
return
num_categories_
;
}
bool
GetReplacement
()
const
{
return
replacement_
;
}
MSRStatus
execute
(
ShardTask
&
tasks
)
override
;
int64_t
GetNumSamples
(
int64_t
dataset_size
,
int64_t
num_classes
)
override
;
private:
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>
categories_
;
std
::
string
category_field_
;
int64_t
num_elements_
;
int64_t
num_categories_
;
bool
replacement_
;
};
}
// namespace mindrecord
}
// namespace mindspore
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/mindrecord/include/shard_operator.h
浏览文件 @
7aac5080
...
...
@@ -43,6 +43,8 @@ class ShardOperator {
virtual
MSRStatus
execute
(
ShardTask
&
tasks
)
=
0
;
virtual
MSRStatus
suf_execute
(
ShardTask
&
tasks
)
{
return
SUCCESS
;
}
virtual
int64_t
GetNumSamples
(
int64_t
dataset_size
,
int64_t
num_classes
)
{
return
-
1
;
}
};
}
// namespace mindrecord
}
// namespace mindspore
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/mindrecord/include/shard_pk_sample.h
0 → 100644
浏览文件 @
7aac5080
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_
#define MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "mindrecord/include/shard_operator.h"
#include "mindrecord/include/shard_shuffle.h"
#include "mindrecord/include/shard_category.h"
namespace
mindspore
{
namespace
mindrecord
{
class
ShardPkSample
:
public
ShardCategory
{
public:
ShardPkSample
(
const
std
::
string
&
category_field
,
int64_t
num_elements
);
ShardPkSample
(
const
std
::
string
&
category_field
,
int64_t
num_elements
,
int64_t
num_categories
);
ShardPkSample
(
const
std
::
string
&
category_field
,
int64_t
num_elements
,
int64_t
num_categories
,
uint32_t
seed
);
~
ShardPkSample
()
override
{};
MSRStatus
suf_execute
(
ShardTask
&
tasks
)
override
;
private:
bool
shuffle_
;
std
::
shared_ptr
<
ShardShuffle
>
shuffle_op_
;
};
}
// namespace mindrecord
}
// namespace mindspore
#endif // MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/mindrecord/include/shard_reader.h
浏览文件 @
7aac5080
...
...
@@ -115,9 +115,10 @@ class ShardReader {
/// \brief get the number of rows in database
/// \param[in] file_path the path of ONE file, any file in dataset is fine
/// \param[in] op smart pointer refer to ShardCategory or ShardSample object
/// \param[out] count # of rows
/// \return MSRStatus the status of MSRStatus
MSRStatus
CountTotalRows
(
const
std
::
string
&
file_path
,
int64_t
*
count
);
MSRStatus
CountTotalRows
(
const
std
::
string
&
file_path
,
const
std
::
shared_ptr
<
ShardOperator
>
&
op
,
int64_t
*
count
);
/// \brief shuffle task with incremental seed
/// \return void
...
...
@@ -197,6 +198,9 @@ class ShardReader {
/// \brief get NLP flag
bool
get_nlp_flag
();
/// \brief get all classes
MSRStatus
GetAllClasses
(
const
std
::
string
&
category_field
,
std
::
set
<
std
::
string
>
&
categories
);
protected:
/// \brief sqlite call back function
static
int
SelectCallback
(
void
*
p_data
,
int
num_fields
,
char
**
p_fields
,
char
**
p_col_names
);
...
...
@@ -249,8 +253,8 @@ class ShardReader {
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
);
/// \brief create category-applied task list
int
CreateTasksByCategory
(
const
std
::
vector
<
std
::
tuple
<
int
,
int
,
int
,
uint64_t
>>
&
row_group_summary
,
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
);
MSRStatus
CreateTasksByCategory
(
const
std
::
vector
<
std
::
tuple
<
int
,
int
,
int
,
uint64_t
>>
&
row_group_summary
,
const
std
::
shared_ptr
<
ShardOperator
>
&
op
);
/// \brief create task list in row-reader mode
MSRStatus
CreateTasksByRow
(
const
std
::
vector
<
std
::
tuple
<
int
,
int
,
int
,
uint64_t
>>
&
row_group_summary
,
...
...
@@ -284,6 +288,12 @@ class ShardReader {
MSRStatus
ReadBlob
(
const
int
&
shard_id
,
const
uint64_t
&
page_offset
,
const
int
&
page_length
,
const
int
&
buf_id
);
/// \brief get classes in one shard
void
GetClassesInShard
(
sqlite3
*
db
,
int
shard_id
,
const
std
::
string
sql
,
std
::
set
<
std
::
string
>
&
categories
);
/// \brief get number of classes
int64_t
GetNumClasses
(
const
std
::
string
&
file_path
,
const
std
::
string
&
category_field
);
protected:
uint64_t
header_size_
;
// header size
uint64_t
page_size_
;
// page size
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/mindrecord/include/shard_sample.h
浏览文件 @
7aac5080
...
...
@@ -41,8 +41,11 @@ class ShardSample : public ShardOperator {
const
std
::
pair
<
int
,
int
>
get_partitions
()
const
;
MSRStatus
execute
(
ShardTask
&
tasks
)
override
;
MSRStatus
suf_execute
(
ShardTask
&
tasks
)
override
;
int64_t
GetNumSamples
(
int64_t
dataset_size
,
int64_t
num_classes
)
override
;
private:
int
numerator_
;
int
denominator_
;
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/mindrecord/include/shard_shuffle.h
浏览文件 @
7aac5080
...
...
@@ -24,7 +24,7 @@ namespace mindspore {
namespace
mindrecord
{
class
ShardShuffle
:
public
ShardOperator
{
public:
explicit
ShardShuffle
(
uint32_t
seed
=
0
);
explicit
ShardShuffle
(
uint32_t
seed
=
0
,
ShuffleType
shuffle_type
=
kShuffleCategory
);
~
ShardShuffle
()
override
{};
...
...
@@ -32,6 +32,7 @@ class ShardShuffle : public ShardOperator {
private:
uint32_t
shuffle_seed_
;
ShuffleType
shuffle_type_
;
};
}
// namespace mindrecord
}
// namespace mindspore
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/mindrecord/include/shard_task.h
浏览文件 @
7aac5080
...
...
@@ -41,7 +41,9 @@ class ShardTask {
std
::
tuple
<
std
::
tuple
<
int
,
int
>
,
std
::
vector
<
uint64_t
>
,
json
>
&
get_task_by_id
(
size_t
id
);
static
ShardTask
Combine
(
std
::
vector
<
ShardTask
>
&
category_tasks
);
std
::
tuple
<
std
::
tuple
<
int
,
int
>
,
std
::
vector
<
uint64_t
>
,
json
>
&
get_random_task
();
static
ShardTask
Combine
(
std
::
vector
<
ShardTask
>
&
category_tasks
,
bool
replacement
,
int64_t
num_elements
);
uint32_t
categories
=
1
;
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/mindrecord/io/shard_reader.cc
浏览文件 @
7aac5080
...
...
@@ -315,6 +315,43 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql,
return
ConvertLabelToJson
(
labels
,
fs
,
offsets
,
shard_id
,
columns
,
column_values
);
}
MSRStatus
ShardReader
::
GetAllClasses
(
const
std
::
string
&
category_field
,
std
::
set
<
std
::
string
>
&
categories
)
{
auto
ret
=
ShardIndexGenerator
::
GenerateFieldName
(
std
::
make_pair
(
column_schema_id_
[
category_field
],
category_field
));
if
(
SUCCESS
!=
ret
.
first
)
{
return
FAILED
;
}
std
::
string
sql
=
"SELECT DISTINCT "
+
ret
.
second
+
" FROM INDEXES"
;
std
::
vector
<
std
::
thread
>
threads
=
std
::
vector
<
std
::
thread
>
(
shard_count_
);
for
(
int
x
=
0
;
x
<
shard_count_
;
x
++
)
{
threads
[
x
]
=
std
::
thread
(
&
ShardReader
::
GetClassesInShard
,
this
,
database_paths_
[
x
],
x
,
sql
,
std
::
ref
(
categories
));
}
for
(
int
x
=
0
;
x
<
shard_count_
;
x
++
)
{
threads
[
x
].
join
();
}
return
SUCCESS
;
}
void
ShardReader
::
GetClassesInShard
(
sqlite3
*
db
,
int
shard_id
,
const
std
::
string
sql
,
std
::
set
<
std
::
string
>
&
categories
)
{
if
(
nullptr
==
db
)
{
return
;
}
std
::
vector
<
std
::
vector
<
std
::
string
>>
columns
;
char
*
errmsg
=
nullptr
;
int
ret
=
sqlite3_exec
(
db
,
common
::
SafeCStr
(
sql
),
SelectCallback
,
&
columns
,
&
errmsg
);
if
(
ret
!=
SQLITE_OK
)
{
sqlite3_free
(
errmsg
);
sqlite3_close
(
db
);
MS_LOG
(
ERROR
)
<<
"Error in select sql statement, sql:"
<<
common
::
SafeCStr
(
sql
)
<<
", error: "
<<
errmsg
;
return
;
}
MS_LOG
(
INFO
)
<<
"Get"
<<
static_cast
<
int
>
(
columns
.
size
())
<<
" records from shard "
<<
shard_id
<<
" index."
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
columns
.
size
());
++
i
)
{
categories
.
emplace
(
columns
[
i
][
0
]);
}
}
ROW_GROUPS
ShardReader
::
ReadAllRowGroup
(
std
::
vector
<
std
::
string
>
&
columns
)
{
std
::
string
fields
=
"ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END"
;
std
::
vector
<
std
::
vector
<
std
::
vector
<
uint64_t
>>>
offsets
(
shard_count_
,
std
::
vector
<
std
::
vector
<
uint64_t
>>
{});
...
...
@@ -667,11 +704,64 @@ MSRStatus ShardReader::Finish() {
return
SUCCESS
;
}
MSRStatus
ShardReader
::
CountTotalRows
(
const
std
::
string
&
file_path
,
int64_t
*
count
)
{
int64_t
ShardReader
::
GetNumClasses
(
const
std
::
string
&
file_path
,
const
std
::
string
&
category_field
)
{
ShardHeader
sh
=
ShardHeader
();
if
(
sh
.
Build
(
file_path
)
==
FAILED
)
{
return
-
1
;
}
auto
header
=
std
::
make_shared
<
ShardHeader
>
(
sh
);
auto
file_paths
=
header
->
get_shard_addresses
();
auto
shard_count
=
file_paths
.
size
();
auto
index_fields
=
header
->
get_fields
();
std
::
map
<
std
::
string
,
int64_t
>
map_schema_id_fields
;
for
(
auto
&
field
:
index_fields
)
{
map_schema_id_fields
[
field
.
second
]
=
field
.
first
;
}
auto
ret
=
ShardIndexGenerator
::
GenerateFieldName
(
std
::
make_pair
(
map_schema_id_fields
[
category_field
],
category_field
));
if
(
SUCCESS
!=
ret
.
first
)
{
return
-
1
;
}
std
::
string
sql
=
"SELECT DISTINCT "
+
ret
.
second
+
" FROM INDEXES"
;
std
::
vector
<
std
::
thread
>
threads
=
std
::
vector
<
std
::
thread
>
(
shard_count
);
std
::
set
<
std
::
string
>
categories
;
for
(
int
x
=
0
;
x
<
shard_count
;
x
++
)
{
sqlite3
*
db
=
nullptr
;
int
rc
=
sqlite3_open_v2
(
common
::
SafeCStr
(
file_paths
[
x
]
+
".db"
),
&
db
,
SQLITE_OPEN_READONLY
,
nullptr
);
if
(
SQLITE_OK
!=
rc
)
{
MS_LOG
(
ERROR
)
<<
"Can't open database, error: "
<<
sqlite3_errmsg
(
db
);
return
-
1
;
}
threads
[
x
]
=
std
::
thread
(
&
ShardReader
::
GetClassesInShard
,
this
,
db
,
x
,
sql
,
std
::
ref
(
categories
));
}
for
(
int
x
=
0
;
x
<
shard_count
;
x
++
)
{
threads
[
x
].
join
();
}
return
categories
.
size
();
}
MSRStatus
ShardReader
::
CountTotalRows
(
const
std
::
string
&
file_path
,
const
std
::
shared_ptr
<
ShardOperator
>
&
op
,
int64_t
*
count
)
{
if
(
Init
(
file_path
)
==
FAILED
)
{
return
FAILED
;
}
*
count
=
num_rows_
;
int64_t
num_samples
=
num_rows_
;
if
(
std
::
dynamic_pointer_cast
<
ShardCategory
>
(
op
))
{
auto
category_op
=
std
::
dynamic_pointer_cast
<
ShardCategory
>
(
op
);
std
::
string
category_field
=
category_op
->
GetCategoryField
();
auto
num_classes
=
GetNumClasses
(
file_path
,
category_field
);
num_samples
=
category_op
->
GetNumSamples
(
num_rows_
,
num_classes
);
}
else
if
(
std
::
dynamic_pointer_cast
<
ShardSample
>
(
op
))
{
num_samples
=
op
->
GetNumSamples
(
num_rows_
,
0
);
}
else
{
}
if
(
-
1
==
num_samples
)
{
MS_LOG
(
ERROR
)
<<
"Failed to get dataset size."
;
return
FAILED
;
}
*
count
=
num_samples
;
return
SUCCESS
;
}
...
...
@@ -793,6 +883,8 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) {
thread_set_
[
x
]
=
std
::
thread
(
&
ShardReader
::
ConsumerByRow
,
this
,
x
);
}
}
MS_LOG
(
INFO
)
<<
"Launch read thread successfully."
;
return
SUCCESS
;
}
...
...
@@ -828,44 +920,67 @@ MSRStatus ShardReader::CreateTasksByBlock(const std::vector<std::tuple<int, int,
return
SUCCESS
;
}
int
ShardReader
::
CreateTasksByCategory
(
const
std
::
vector
<
std
::
tuple
<
int
,
int
,
int
,
uint64_t
>>
&
row_group_summary
,
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
)
{
MSRStatus
ShardReader
::
CreateTasksByCategory
(
const
std
::
vector
<
std
::
tuple
<
int
,
int
,
int
,
uint64_t
>>
&
row_group_summary
,
const
std
::
shared_ptr
<
ShardOperator
>
&
op
)
{
vector
<
std
::
string
>
columns
=
GetAllColumns
();
CheckIfColumnInIndex
(
columns
);
int
category_operator
=
-
1
;
for
(
uint32_t
i
=
0
;
i
<
operators
.
size
();
++
i
)
{
const
auto
&
op
=
operators
[
i
];
if
(
std
::
dynamic_pointer_cast
<
ShardCategory
>
(
op
))
category_operator
=
static_cast
<
int
>
(
i
);
auto
category_op
=
std
::
dynamic_pointer_cast
<
ShardCategory
>
(
op
);
auto
categories
=
category_op
->
get_categories
();
int64_t
num_elements
=
category_op
->
GetNumElements
();
if
(
num_elements
<=
0
)
{
MS_LOG
(
ERROR
)
<<
"Parameter num_element is not positive"
;
return
FAILED
;
}
if
(
categories
.
empty
()
==
true
)
{
std
::
string
category_field
=
category_op
->
GetCategoryField
();
int64_t
num_categories
=
category_op
->
GetNumCategories
();
if
(
num_categories
<=
0
)
{
MS_LOG
(
ERROR
)
<<
"Parameter num_categories is not positive"
;
return
FAILED
;
}
std
::
set
<
std
::
string
>
categories_set
;
auto
ret
=
GetAllClasses
(
category_field
,
categories_set
);
if
(
SUCCESS
!=
ret
)
{
return
FAILED
;
}
int
i
=
0
;
for
(
auto
it
=
categories_set
.
begin
();
it
!=
categories_set
.
end
()
&&
i
<
num_categories
;
++
it
)
{
categories
.
emplace_back
(
category_field
,
*
it
);
i
++
;
}
}
if
(
category_operator
==
-
1
)
return
category_operator
;
auto
categories
=
std
::
dynamic_pointer_cast
<
ShardCategory
>
(
operators
[
category_operator
])
->
get_categories
();
// Generate task list, a task will create a batch
std
::
vector
<
ShardTask
>
categoryTasks
(
categories
.
size
());
for
(
uint32_t
categoryNo
=
0
;
categoryNo
<
categories
.
size
();
++
categoryNo
)
{
int
category_index
=
0
;
for
(
const
auto
&
rg
:
row_group_summary
)
{
if
(
category_index
>=
num_elements
)
break
;
auto
shard_id
=
std
::
get
<
0
>
(
rg
);
auto
group_id
=
std
::
get
<
1
>
(
rg
);
auto
details
=
ReadRowGroupCriteria
(
group_id
,
shard_id
,
categories
[
categoryNo
],
columns
);
if
(
SUCCESS
!=
std
::
get
<
0
>
(
details
))
{
return
-
2
;
return
FAILED
;
}
auto
offsets
=
std
::
get
<
4
>
(
details
);
auto
number_of_rows
=
offsets
.
size
();
for
(
uint32_t
iStart
=
0
;
iStart
<
number_of_rows
;
iStart
+=
1
)
{
categoryTasks
[
categoryNo
].
InsertTask
(
shard_id
,
group_id
,
std
::
get
<
4
>
(
details
)[
iStart
],
std
::
get
<
5
>
(
details
)[
iStart
]);
if
(
category_index
<
num_elements
)
{
categoryTasks
[
categoryNo
].
InsertTask
(
shard_id
,
group_id
,
std
::
get
<
4
>
(
details
)[
iStart
],
std
::
get
<
5
>
(
details
)[
iStart
]);
category_index
++
;
}
}
}
MS_LOG
(
INFO
)
<<
"Category #"
<<
categoryNo
<<
" has "
<<
categoryTasks
[
categoryNo
].
Size
()
<<
" tasks"
;
}
tasks_
=
ShardTask
::
Combine
(
categoryTasks
);
return
category_operator
;
tasks_
=
ShardTask
::
Combine
(
categoryTasks
,
category_op
->
GetReplacement
(),
num_elements
);
if
(
SUCCESS
!=
(
*
category_op
)(
tasks_
))
{
return
FAILED
;
}
return
SUCCESS
;
}
MSRStatus
ShardReader
::
CreateTasksByRow
(
const
std
::
vector
<
std
::
tuple
<
int
,
int
,
int
,
uint64_t
>>
&
row_group_summary
,
...
...
@@ -896,14 +1011,26 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, i
MSRStatus
ShardReader
::
CreateTasks
(
const
std
::
vector
<
std
::
tuple
<
int
,
int
,
int
,
uint64_t
>>
&
row_group_summary
,
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
)
{
if
(
block_reader_
)
{
CreateTasksByBlock
(
row_group_summary
,
operators
);
if
(
SUCCESS
!=
CreateTasksByBlock
(
row_group_summary
,
operators
))
{
return
FAILED
;
}
}
else
{
int
category_operator
=
CreateTasksByCategory
(
row_group_summary
,
operators
);
if
(
category_operator
==
-
1
)
{
CreateTasksByRow
(
row_group_summary
,
operators
);
int
category_operator
=
-
1
;
for
(
uint32_t
i
=
0
;
i
<
operators
.
size
();
++
i
)
{
const
auto
&
op
=
operators
[
i
];
if
(
std
::
dynamic_pointer_cast
<
ShardCategory
>
(
op
))
{
category_operator
=
static_cast
<
int
>
(
i
);
break
;
}
}
if
(
category_operator
==
-
2
)
{
return
FAILED
;
if
(
-
1
==
category_operator
)
{
if
(
SUCCESS
!=
CreateTasksByRow
(
row_group_summary
,
operators
))
{
return
FAILED
;
}
}
else
{
if
(
SUCCESS
!=
CreateTasksByCategory
(
row_group_summary
,
operators
[
category_operator
]))
{
return
FAILED
;
}
}
}
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/mindrecord/meta/shard_category.cc
浏览文件 @
7aac5080
...
...
@@ -18,11 +18,30 @@
namespace
mindspore
{
namespace
mindrecord
{
ShardCategory
::
ShardCategory
(
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>
&
categories
)
:
categories_
(
categories
)
{}
ShardCategory
::
ShardCategory
(
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>
&
categories
,
int64_t
num_elements
,
bool
replacement
)
:
categories_
(
categories
),
category_field_
(
""
),
num_elements_
(
num_elements
),
num_categories_
(
0
),
replacement_
(
replacement
)
{}
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>
&
ShardCategory
::
get_categories
()
const
{
return
categories_
;
}
ShardCategory
::
ShardCategory
(
const
std
::
string
&
category_field
,
int64_t
num_elements
,
int64_t
num_categories
,
bool
replacement
)
:
categories_
({}),
category_field_
(
category_field
),
num_elements_
(
num_elements
),
num_categories_
(
num_categories
),
replacement_
(
replacement
)
{}
MSRStatus
ShardCategory
::
execute
(
ShardTask
&
tasks
)
{
return
SUCCESS
;
}
int64_t
ShardCategory
::
GetNumSamples
(
int64_t
dataset_size
,
int64_t
num_classes
)
{
if
(
dataset_size
==
0
)
return
dataset_size
;
if
(
dataset_size
>
0
&&
num_categories_
>
0
&&
num_elements_
>
0
)
{
return
std
::
min
(
num_categories_
,
num_classes
)
*
num_elements_
;
}
return
-
1
;
}
}
// namespace mindrecord
}
// namespace mindspore
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc
0 → 100644
浏览文件 @
7aac5080
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindrecord/include/shard_pk_sample.h"
using
mindspore
::
LogStream
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
MsLogLevel
::
ERROR
;
namespace
mindspore
{
namespace
mindrecord
{
ShardPkSample
::
ShardPkSample
(
const
std
::
string
&
category_field
,
int64_t
num_elements
)
:
ShardCategory
(
category_field
,
num_elements
,
std
::
numeric_limits
<
int64_t
>::
max
(),
true
),
shuffle_
(
false
)
{}
ShardPkSample
::
ShardPkSample
(
const
std
::
string
&
category_field
,
int64_t
num_elements
,
int64_t
num_categories
)
:
ShardCategory
(
category_field
,
num_elements
,
num_categories
,
true
),
shuffle_
(
false
)
{}
ShardPkSample
::
ShardPkSample
(
const
std
::
string
&
category_field
,
int64_t
num_elements
,
int64_t
num_categories
,
uint32_t
seed
)
:
ShardCategory
(
category_field
,
num_elements
,
num_categories
,
true
),
shuffle_
(
true
)
{
shuffle_op_
=
std
::
make_shared
<
ShardShuffle
>
(
seed
,
kShuffleSample
);
// do shuffle and replacement
}
MSRStatus
ShardPkSample
::
suf_execute
(
ShardTask
&
tasks
)
{
if
(
shuffle_
==
true
)
{
if
(
SUCCESS
!=
(
*
shuffle_op_
)(
tasks
))
{
return
FAILED
;
}
}
return
SUCCESS
;
}
}
// namespace mindrecord
}
// namespace mindspore
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/mindrecord/meta/shard_sample.cc
浏览文件 @
7aac5080
...
...
@@ -56,6 +56,24 @@ ShardSample::ShardSample(const std::vector<int64_t> &indices, uint32_t seed)
shuffle_op_
=
std
::
make_shared
<
ShardShuffle
>
(
seed
);
}
int64_t
ShardSample
::
GetNumSamples
(
int64_t
dataset_size
,
int64_t
num_classes
)
{
if
(
sampler_type_
==
kCustomTopNSampler
)
{
return
no_of_samples_
;
}
if
(
sampler_type_
==
kCustomTopPercentSampler
)
{
if
(
dataset_size
%
denominator_
==
0
)
{
return
dataset_size
/
denominator_
*
numerator_
;
}
else
{
return
dataset_size
/
denominator_
*
numerator_
+
1
;
}
}
if
(
sampler_type_
==
kSubsetRandomSampler
)
{
return
indices_
.
size
();
}
return
-
1
;
}
const
std
::
pair
<
int
,
int
>
ShardSample
::
get_partitions
()
const
{
if
(
numerator_
==
1
&&
denominator_
>
1
)
{
return
std
::
pair
<
int
,
int
>
(
denominator_
,
partition_id_
);
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc
浏览文件 @
7aac5080
...
...
@@ -20,25 +20,33 @@
namespace
mindspore
{
namespace
mindrecord
{
ShardShuffle
::
ShardShuffle
(
uint32_t
seed
)
:
shuffle_seed_
(
seed
)
{}
ShardShuffle
::
ShardShuffle
(
uint32_t
seed
,
ShuffleType
shuffle_type
)
:
shuffle_seed_
(
seed
),
shuffle_type_
(
shuffle_type
)
{}
MSRStatus
ShardShuffle
::
execute
(
ShardTask
&
tasks
)
{
if
(
tasks
.
categories
<
1
)
{
return
FAILED
;
}
uint32_t
individual_size
=
tasks
.
Size
()
/
tasks
.
categories
;
std
::
vector
<
std
::
vector
<
int
>>
new_permutations
(
tasks
.
categories
,
std
::
vector
<
int
>
(
individual_size
));
for
(
uint32_t
i
=
0
;
i
<
tasks
.
categories
;
i
++
)
{
for
(
uint32_t
j
=
0
;
j
<
individual_size
;
j
++
)
new_permutations
[
i
][
j
]
=
static_cast
<
int
>
(
j
);
std
::
shuffle
(
new_permutations
[
i
].
begin
(),
new_permutations
[
i
].
end
(),
std
::
default_random_engine
(
shuffle_seed_
));
}
shuffle_seed_
++
;
tasks
.
permutation_
.
clear
();
for
(
uint32_t
j
=
0
;
j
<
individual_size
;
j
++
)
{
if
(
shuffle_type_
==
kShuffleSample
)
{
if
(
tasks
.
permutation_
.
empty
()
==
true
)
{
tasks
.
MakePerm
();
}
std
::
shuffle
(
tasks
.
permutation_
.
begin
(),
tasks
.
permutation_
.
end
(),
std
::
default_random_engine
(
shuffle_seed_
));
}
else
{
// shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
uint32_t
individual_size
=
tasks
.
Size
()
/
tasks
.
categories
;
std
::
vector
<
std
::
vector
<
int
>>
new_permutations
(
tasks
.
categories
,
std
::
vector
<
int
>
(
individual_size
));
for
(
uint32_t
i
=
0
;
i
<
tasks
.
categories
;
i
++
)
{
tasks
.
permutation_
.
push_back
(
new_permutations
[
i
][
j
]
*
static_cast
<
int
>
(
tasks
.
categories
)
+
static_cast
<
int
>
(
i
));
for
(
uint32_t
j
=
0
;
j
<
individual_size
;
j
++
)
new_permutations
[
i
][
j
]
=
static_cast
<
int
>
(
j
);
std
::
shuffle
(
new_permutations
[
i
].
begin
(),
new_permutations
[
i
].
end
(),
std
::
default_random_engine
(
shuffle_seed_
));
}
tasks
.
permutation_
.
clear
();
for
(
uint32_t
j
=
0
;
j
<
individual_size
;
j
++
)
{
for
(
uint32_t
i
=
0
;
i
<
tasks
.
categories
;
i
++
)
{
tasks
.
permutation_
.
push_back
(
new_permutations
[
i
][
j
]
*
static_cast
<
int
>
(
tasks
.
categories
)
+
static_cast
<
int
>
(
i
));
}
}
}
shuffle_seed_
++
;
return
SUCCESS
;
}
}
// namespace mindrecord
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/mindrecord/meta/shard_task.cc
浏览文件 @
7aac5080
...
...
@@ -35,8 +35,6 @@ void ShardTask::InsertTask(int shard_id, int group_id, const std::vector<uint64_
MS_LOG
(
DEBUG
)
<<
"Into insert task, shard_id: "
<<
shard_id
<<
", group_id: "
<<
group_id
<<
", label: "
<<
label
.
dump
()
<<
", size of task_list_: "
<<
task_list_
.
size
()
<<
"."
;
task_list_
.
emplace_back
(
std
::
make_tuple
(
shard_id
,
group_id
),
offset
,
label
);
MS_LOG
(
DEBUG
)
<<
"Out of insert task, shard_id: "
<<
shard_id
<<
", group_id: "
<<
group_id
<<
", label: "
<<
label
.
dump
()
<<
", size of task_list_: "
<<
task_list_
.
size
()
<<
"."
;
}
void
ShardTask
::
InsertTask
(
std
::
tuple
<
std
::
tuple
<
int
,
int
>
,
std
::
vector
<
uint64_t
>
,
json
>
task
)
{
...
...
@@ -44,9 +42,6 @@ void ShardTask::InsertTask(std::tuple<std::tuple<int, int>, std::vector<uint64_t
<<
", group_id: "
<<
std
::
get
<
1
>
(
std
::
get
<
0
>
(
task
))
<<
", label: "
<<
std
::
get
<
2
>
(
task
).
dump
()
<<
", size of task_list_: "
<<
task_list_
.
size
()
<<
"."
;
task_list_
.
push_back
(
std
::
move
(
task
));
MS_LOG
(
DEBUG
)
<<
"Out of insert task, shard_id: "
<<
std
::
get
<
0
>
(
std
::
get
<
0
>
(
task
))
<<
", group_id: "
<<
std
::
get
<
1
>
(
std
::
get
<
0
>
(
task
))
<<
", label: "
<<
std
::
get
<
2
>
(
task
).
dump
()
<<
", size of task_list_: "
<<
task_list_
.
size
()
<<
"."
;
}
void
ShardTask
::
PopBack
()
{
task_list_
.
pop_back
();
}
...
...
@@ -69,18 +64,39 @@ std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::get_ta
return
task_list_
[
id
];
}
ShardTask
ShardTask
::
Combine
(
std
::
vector
<
ShardTask
>
&
category_tasks
)
{
std
::
tuple
<
std
::
tuple
<
int
,
int
>
,
std
::
vector
<
uint64_t
>
,
json
>
&
ShardTask
::
get_random_task
()
{
std
::
random_device
rd
;
std
::
mt19937
gen
(
rd
());
std
::
uniform_int_distribution
<>
dis
(
0
,
task_list_
.
size
()
-
1
);
return
task_list_
[
dis
(
gen
)];
}
ShardTask
ShardTask
::
Combine
(
std
::
vector
<
ShardTask
>
&
category_tasks
,
bool
replacement
,
int64_t
num_elements
)
{
ShardTask
res
;
if
(
category_tasks
.
empty
())
return
res
;
auto
total_categories
=
category_tasks
.
size
();
res
.
categories
=
static_cast
<
uint32_t
>
(
total_categories
);
auto
minTasks
=
category_tasks
[
0
].
Size
();
for
(
uint32_t
i
=
1
;
i
<
total_categories
;
i
++
)
{
minTasks
=
std
::
min
(
minTasks
,
category_tasks
[
i
].
Size
());
}
for
(
uint32_t
task_no
=
0
;
task_no
<
minTasks
;
task_no
++
)
{
if
(
replacement
==
false
)
{
auto
minTasks
=
category_tasks
[
0
].
Size
();
for
(
uint32_t
i
=
1
;
i
<
total_categories
;
i
++
)
{
minTasks
=
std
::
min
(
minTasks
,
category_tasks
[
i
].
Size
());
}
for
(
uint32_t
task_no
=
0
;
task_no
<
minTasks
;
task_no
++
)
{
for
(
uint32_t
i
=
0
;
i
<
total_categories
;
i
++
)
{
res
.
InsertTask
(
std
::
move
(
category_tasks
[
i
].
get_task_by_id
(
static_cast
<
int
>
(
task_no
))));
}
}
}
else
{
auto
maxTasks
=
category_tasks
[
0
].
Size
();
for
(
uint32_t
i
=
1
;
i
<
total_categories
;
i
++
)
{
maxTasks
=
std
::
max
(
maxTasks
,
category_tasks
[
i
].
Size
());
}
if
(
num_elements
!=
std
::
numeric_limits
<
int64_t
>::
max
())
{
maxTasks
=
static_cast
<
decltype
(
maxTasks
)
>
(
num_elements
);
}
for
(
uint32_t
i
=
0
;
i
<
total_categories
;
i
++
)
{
res
.
InsertTask
(
std
::
move
(
category_tasks
[
i
].
get_task_by_id
(
static_cast
<
int
>
(
task_no
))));
for
(
uint32_t
j
=
0
;
j
<
maxTasks
;
j
++
)
{
res
.
InsertTask
(
category_tasks
[
i
].
get_random_task
());
}
}
}
return
res
;
...
...
This diff is collapsed.
Click to expand it.
mindspore/dataset/engine/datasets.py
浏览文件 @
7aac5080
...
...
@@ -1882,7 +1882,8 @@ class MindDataset(SourceDataset):
block_reader (bool, optional): Whether read data by block mode (default=False).
sampler (Sampler, optional): Object used to choose samples from the
dataset (default=None, sampler is exclusive
with shuffle and block_reader). Support list: SubsetRandomSampler.
with shuffle and block_reader). Support list: SubsetRandomSampler,
PkSampler
Raises:
ValueError: If num_shards is specified but shard_id is None.
...
...
@@ -1915,8 +1916,10 @@ class MindDataset(SourceDataset):
if
block_reader
is
True
:
logger
.
warning
(
"WARN: global shuffle is not used."
)
if
sampler
is
not
None
and
isinstance
(
sampler
,
samplers
.
SubsetRandomSampler
)
is
False
:
raise
ValueError
(
"the sampler is not supported yet."
)
if
sampler
is
not
None
:
if
isinstance
(
sampler
,
samplers
.
SubsetRandomSampler
)
is
False
and
\
isinstance
(
sampler
,
samplers
.
PKSampler
)
is
False
:
raise
ValueError
(
"the sampler is not supported yet."
)
# sampler exclusive
if
block_reader
is
True
and
sampler
is
not
None
:
...
...
@@ -1952,7 +1955,7 @@ class MindDataset(SourceDataset):
Number, number of batches.
"""
num_rows
=
MindRecordOp
.
get_num_rows
(
self
.
dataset_file
)
num_rows
=
MindRecordOp
.
get_num_rows
(
self
.
dataset_file
,
self
.
sampler
)
if
self
.
partitions
is
not
None
and
self
.
partitions
[
0
]
>
0
:
if
num_rows
%
self
.
partitions
[
0
]
==
0
:
num_rows
=
num_rows
//
self
.
partitions
[
0
]
...
...
This diff is collapsed.
Click to expand it.
mindspore/dataset/engine/samplers.py
浏览文件 @
7aac5080
...
...
@@ -184,6 +184,8 @@ class PKSampler(BuiltinSampler):
def
create
(
self
):
return
cde
.
PKSampler
(
self
.
num_val
,
self
.
shuffle
)
def
_create_for_minddataset
(
self
):
return
cde
.
MindrecordPkSampler
(
self
.
num_val
,
self
.
shuffle
)
class
RandomSampler
(
BuiltinSampler
):
"""
...
...
This diff is collapsed.
Click to expand it.
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
浏览文件 @
7aac5080
...
...
@@ -25,6 +25,7 @@
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_pk_sample.h"
#include "mindrecord/include/shard_reader.h"
#include "mindrecord/include/shard_sample.h"
#include "mindrecord/include/shard_shuffle.h"
...
...
@@ -146,6 +147,57 @@ TEST_F(TestShardOperator, TestShardSamplePartition) {
ASSERT_TRUE
(
i
<=
10
);
}
TEST_F
(
TestShardOperator
,
TestShardPkSamplerBasic
)
{
MS_LOG
(
INFO
)
<<
common
::
SafeCStr
(
FormatInfo
(
"Test pk sampler"
));
std
::
string
file_name
=
"./imagenet.shard01"
;
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"file_name"
,
"label"
};
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
ops
;
ops
.
push_back
(
std
::
make_shared
<
ShardPkSample
>
(
"label"
,
2
));
ShardReader
dataset
;
dataset
.
Open
(
file_name
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
int
i
=
0
;
while
(
true
)
{
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
std
::
cout
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
())
<<
std
::
endl
;
i
++
;
}
dataset
.
Finish
();
ASSERT_TRUE
(
i
==
20
);
}
// namespace mindrecord
TEST_F
(
TestShardOperator
,
TestShardPkSamplerNumClass
)
{
MS_LOG
(
INFO
)
<<
common
::
SafeCStr
(
FormatInfo
(
"Test pk sampler"
));
std
::
string
file_name
=
"./imagenet.shard01"
;
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"file_name"
,
"label"
};
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
ops
;
ops
.
push_back
(
std
::
make_shared
<
ShardPkSample
>
(
"label"
,
2
,
3
,
0
));
ShardReader
dataset
;
dataset
.
Open
(
file_name
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
int
i
=
0
;
while
(
true
)
{
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
std
::
cout
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
())
<<
std
::
endl
;
i
++
;
}
dataset
.
Finish
();
ASSERT_TRUE
(
i
==
6
);
}
// namespace mindrecord
TEST_F
(
TestShardOperator
,
TestShardCategory
)
{
MS_LOG
(
INFO
)
<<
common
::
SafeCStr
(
FormatInfo
(
"Test read imageNet"
));
...
...
This diff is collapsed.
Click to expand it.
tests/ut/data/mindrecord/testImageNetData/annotation_sampler.txt
0 → 100644
浏览文件 @
7aac5080
image_00001.jpg,164
image_00002.jpg,164
image_00003.jpg,164
image_00004.jpg,599
image_00005.jpg,599
image_00006.jpg,599
image_00007.jpg,13
image_00008.jpg,13
image_00009.jpg,13
image_00010.jpg,13
This diff is collapsed.
Click to expand it.
tests/ut/python/dataset/test_minddataset_sampler.py
浏览文件 @
7aac5080
...
...
@@ -46,7 +46,7 @@ def add_and_remove_cv_file():
if
os
.
path
.
exists
(
"{}.db"
.
format
(
x
)):
os
.
remove
(
"{}.db"
.
format
(
x
))
writer
=
FileWriter
(
CV_FILE_NAME
,
FILES_NUM
)
data
=
get_data
(
CV_DIR_NAME
)
data
=
get_data
(
CV_DIR_NAME
,
True
)
cv_schema_json
=
{
"id"
:
{
"type"
:
"int32"
},
"file_name"
:
{
"type"
:
"string"
},
"label"
:
{
"type"
:
"int32"
},
...
...
@@ -61,6 +61,59 @@ def add_and_remove_cv_file():
os
.
remove
(
"{}.db"
.
format
(
x
))
def
test_cv_minddataset_pk_sample_basic
(
add_and_remove_cv_file
):
"""tutorial for cv minderdataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
sampler
=
ds
.
PKSampler
(
2
)
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
+
"0"
,
columns_list
,
num_readers
,
sampler
=
sampler
)
assert
data_set
.
get_dataset_size
()
==
6
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
"-------------- cv reader basic: {} ------------------------"
.
format
(
num_iter
))
logger
.
info
(
"-------------- item[file_name]:
\
{}------------------------"
.
format
(
""
.
join
([
chr
(
x
)
for
x
in
item
[
"file_name"
]])))
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
num_iter
+=
1
def
test_cv_minddataset_pk_sample_shuffle
(
add_and_remove_cv_file
):
"""tutorial for cv minderdataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
sampler
=
ds
.
PKSampler
(
3
,
None
,
True
)
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
+
"0"
,
columns_list
,
num_readers
,
sampler
=
sampler
)
assert
data_set
.
get_dataset_size
()
==
9
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
"-------------- cv reader basic: {} ------------------------"
.
format
(
num_iter
))
logger
.
info
(
"-------------- item[file_name]:
\
{}------------------------"
.
format
(
""
.
join
([
chr
(
x
)
for
x
in
item
[
"file_name"
]])))
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
num_iter
+=
1
def
test_cv_minddataset_pk_sample_out_of_range
(
add_and_remove_cv_file
):
"""tutorial for cv minderdataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
sampler
=
ds
.
PKSampler
(
5
,
None
,
True
)
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
+
"0"
,
columns_list
,
num_readers
,
sampler
=
sampler
)
assert
data_set
.
get_dataset_size
()
==
15
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
"-------------- cv reader basic: {} ------------------------"
.
format
(
num_iter
))
logger
.
info
(
"-------------- item[file_name]:
\
{}------------------------"
.
format
(
""
.
join
([
chr
(
x
)
for
x
in
item
[
"file_name"
]])))
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
num_iter
+=
1
def
test_cv_minddataset_subset_random_sample_basic
(
add_and_remove_cv_file
):
"""tutorial for cv minderdataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
...
...
@@ -69,8 +122,7 @@ def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file):
sampler
=
ds
.
SubsetRandomSampler
(
indices
)
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
+
"0"
,
columns_list
,
num_readers
,
sampler
=
sampler
)
data
=
get_data
(
CV_DIR_NAME
)
assert
data_set
.
get_dataset_size
()
==
10
assert
data_set
.
get_dataset_size
()
==
5
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
...
...
@@ -93,8 +145,7 @@ def test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file):
sampler
=
ds
.
SubsetRandomSampler
(
indices
)
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
+
"0"
,
columns_list
,
num_readers
,
sampler
=
sampler
)
data
=
get_data
(
CV_DIR_NAME
)
assert
data_set
.
get_dataset_size
()
==
10
assert
data_set
.
get_dataset_size
()
==
6
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
...
...
@@ -117,8 +168,7 @@ def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file):
sampler
=
ds
.
SubsetRandomSampler
(
indices
)
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
+
"0"
,
columns_list
,
num_readers
,
sampler
=
sampler
)
data
=
get_data
(
CV_DIR_NAME
)
assert
data_set
.
get_dataset_size
()
==
10
assert
data_set
.
get_dataset_size
()
==
0
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
...
...
@@ -133,7 +183,7 @@ def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file):
assert
num_iter
==
0
def
test_cv_minddataset_subset_random_sample_out_range
(
add_and_remove_cv_file
):
def
test_cv_minddataset_subset_random_sample_out_
of_
range
(
add_and_remove_cv_file
):
"""tutorial for cv minderdataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
...
...
@@ -141,8 +191,7 @@ def test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file):
sampler
=
ds
.
SubsetRandomSampler
(
indices
)
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
+
"0"
,
columns_list
,
num_readers
,
sampler
=
sampler
)
data
=
get_data
(
CV_DIR_NAME
)
assert
data_set
.
get_dataset_size
()
==
10
assert
data_set
.
get_dataset_size
()
==
5
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
...
...
@@ -165,8 +214,7 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
sampler
=
ds
.
SubsetRandomSampler
(
indices
)
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
+
"0"
,
columns_list
,
num_readers
,
sampler
=
sampler
)
data
=
get_data
(
CV_DIR_NAME
)
assert
data_set
.
get_dataset_size
()
==
10
assert
data_set
.
get_dataset_size
()
==
5
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
...
...
@@ -181,7 +229,7 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
assert
num_iter
==
5
def
get_data
(
dir_name
):
def
get_data
(
dir_name
,
sampler
=
False
):
"""
usage: get data from imagenet dataset
params:
...
...
@@ -191,7 +239,10 @@ def get_data(dir_name):
if
not
os
.
path
.
isdir
(
dir_name
):
raise
IOError
(
"Directory {} not exists"
.
format
(
dir_name
))
img_dir
=
os
.
path
.
join
(
dir_name
,
"images"
)
ann_file
=
os
.
path
.
join
(
dir_name
,
"annotation.txt"
)
if
sampler
:
ann_file
=
os
.
path
.
join
(
dir_name
,
"annotation_sampler.txt"
)
else
:
ann_file
=
os
.
path
.
join
(
dir_name
,
"annotation.txt"
)
with
open
(
ann_file
,
"r"
)
as
file_reader
:
lines
=
file_reader
.
readlines
()
...
...
This diff is collapsed.
Click to expand it.
tests/ut/python/dataset/test_serdes_dataset.py
浏览文件 @
7aac5080
...
...
@@ -243,7 +243,7 @@ def test_minddataset(add_and_remove_cv_file):
assert
ds1_json
==
ds2_json
data
=
get_data
(
CV_DIR_NAME
)
assert
data_set
.
get_dataset_size
()
==
10
assert
data_set
.
get_dataset_size
()
==
5
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
num_iter
+=
1
...
...
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
反馈
建议
客服
返回
顶部