Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
4e2b1eec
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4e2b1eec
编写于
7月 31, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 31, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3119 Delete parameter block_reader for MindDataset
Merge pull request !3119 from dengyutao/remove_blockreader
上级
b55e5e2c
388e43f7
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
45 addition
and
412 deletion
+45
-412
mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
+0
-2
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc
...inddata/dataset/engine/datasetops/source/mindrecord_op.cc
+18
-71
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h
...minddata/dataset/engine/datasetops/source/mindrecord_op.h
+2
-21
mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
+1
-30
mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
+21
-194
mindspore/dataset/core/validator_helpers.py
mindspore/dataset/core/validator_helpers.py
+0
-4
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+2
-18
mindspore/dataset/engine/serializer_deserializer.py
mindspore/dataset/engine/serializer_deserializer.py
+1
-1
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+0
-2
tests/ut/cpp/dataset/mind_record_op_test.cc
tests/ut/cpp/dataset/mind_record_op_test.cc
+0
-1
tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
+0
-25
tests/ut/python/dataset/test_minddataset.py
tests/ut/python/dataset/test_minddataset.py
+0
-43
未找到文件。
mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
浏览文件 @
4e2b1eec
...
...
@@ -672,8 +672,6 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
if
(
!
value
.
is_none
())
{
if
(
key
==
"num_parallel_workers"
)
{
(
void
)
builder
->
SetNumMindRecordWorkers
(
ToInt
(
value
));
}
else
if
(
key
==
"block_reader"
&&
ToBool
(
value
)
==
true
)
{
(
void
)
builder
->
SetBlockReader
();
}
else
if
(
key
==
"sampler"
)
{
int
num_padded
=
0
;
if
(
!
args
[
"num_padded"
].
is_none
())
{
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc
浏览文件 @
4e2b1eec
...
...
@@ -51,7 +51,6 @@ MindRecordOp::Builder::Builder() : build_dataset_file_({}) {
build_num_mind_record_workers_
=
kDefaultMindRecordWorkers
;
build_rows_per_buffer_
=
cfg
->
rows_per_buffer
();
build_op_connector_queue_size_
=
cfg
->
op_connector_size
();
build_block_reader_
=
false
;
builder_num_workers_
=
0
;
build_num_padded_
=
0
;
build_sample_
=
nullptr
;
...
...
@@ -69,10 +68,10 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) {
if
(
build_num_padded_
>
0
)
{
sample_json
=
ToJson
(
build_sample_
);
}
new_mind_record_op
=
std
::
make_shared
<
MindRecordOp
>
(
build_num_mind_record_workers_
,
build_rows_per_buffer_
,
build_dataset_file_
,
build_load_dataset
_
,
build_op_connector_queue_size_
,
build_columns_to_load_
,
build_operators_
,
build_block_reader_
,
build_num_padde
d_
,
sample_json
,
build_sample_bytes_
);
new_mind_record_op
=
std
::
make_shared
<
MindRecordOp
>
(
build_num_mind_record_workers_
,
build_rows_per_buffer_
,
build_dataset_file
_
,
build_load_dataset_
,
build_op_connector_queue_size_
,
build_columns_to_loa
d_
,
build_operators_
,
build_num_padded_
,
sample_json
,
build_sample_bytes_
);
RETURN_IF_NOT_OK
(
new_mind_record_op
->
Init
());
*
ptr
=
std
::
move
(
new_mind_record_op
);
...
...
@@ -113,9 +112,8 @@ mindrecord::json MindRecordOp::Builder::ToJson(const py::handle &obj) {
MindRecordOp
::
MindRecordOp
(
int32_t
num_mind_record_workers
,
int32_t
rows_per_buffer
,
std
::
vector
<
std
::
string
>
dataset_file
,
bool
load_dataset
,
int32_t
op_connector_queue_size
,
const
std
::
vector
<
std
::
string
>
&
columns_to_load
,
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
,
const
bool
&
block_reader
,
int64_t
num_padded
,
const
mindrecord
::
json
&
sample_json
,
const
std
::
map
<
std
::
string
,
std
::
string
>
&
sample_bytes
)
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
,
int64_t
num_padded
,
const
mindrecord
::
json
&
sample_json
,
const
std
::
map
<
std
::
string
,
std
::
string
>
&
sample_bytes
)
:
ParallelOp
(
num_mind_record_workers
,
op_connector_queue_size
),
rows_per_buffer_
(
rows_per_buffer
),
dataset_file_
(
dataset_file
),
...
...
@@ -123,27 +121,21 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf
columns_to_load_
(
columns_to_load
),
operators_
(
operators
),
num_mind_record_workers_
(
num_mind_record_workers
),
block_reader_
(
block_reader
),
num_rows_
(
0
),
buffers_needed_
(
0
),
buf_cnt_
(
0
),
ended_worker_
(
0
),
buffer_water_mark_
(
0
),
num_padded_
(
num_padded
),
sample_json_
(
sample_json
),
sample_bytes_
(
sample_bytes
)
{
io_blk_queues_
.
Init
(
num_workers_
,
op_connector_queue_size
);
if
(
!
block_reader_
)
return
;
for
(
int32_t
i
=
0
;
i
<
num_workers_
;
++
i
)
{
block_buffer_
.
emplace_back
(
std
::
make_unique
<
std
::
vector
<
ShardTuple
>>
(
std
::
vector
<
ShardTuple
>
{}));
}
}
// Private helper method to encapsulate some common construction/reset tasks
Status
MindRecordOp
::
Init
()
{
shard_reader_
=
std
::
make_unique
<
ShardReader
>
();
auto
rc
=
shard_reader_
->
Open
(
dataset_file_
,
load_dataset_
,
num_mind_record_workers_
,
columns_to_load_
,
operators_
,
block_reader_
,
num_padded_
);
num_padded_
);
CHECK_FAIL_RETURN_UNEXPECTED
(
rc
==
MSRStatus
::
SUCCESS
,
"MindRecordOp init failed. Error message: "
+
ErrnoToMessage
(
rc
));
...
...
@@ -264,23 +256,6 @@ Status MindRecordOp::WorkerEntry(int32_t worker_id) {
}
RETURN_IF_NOT_OK
(
GetBufferFromReader
(
&
fetched_buffer
,
buffer_id
,
worker_id
));
RETURN_IF_NOT_OK
(
out_connector_
->
Add
(
worker_id
,
std
::
move
(
fetched_buffer
)));
if
(
!
block_reader_
)
{
RETURN_IF_NOT_OK
(
io_blk_queues_
[
worker_id
]
->
PopFront
(
&
io_block
));
continue
;
}
// update block-reader buffer
block_buffer_
[
buffer_id
%
num_workers_
]
->
clear
();
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
mtx_block_reader_
);
if
(
buffer_id
==
buffer_water_mark_
)
{
buffer_water_mark_
++
;
while
(
block_set_
.
count
(
buffer_water_mark_
)
>
0
)
(
void
)
block_set_
.
erase
(
buffer_water_mark_
++
);
}
else
{
(
void
)
block_set_
.
insert
(
buffer_id
);
}
}
cv_reader_
.
notify_one
();
RETURN_IF_NOT_OK
(
io_blk_queues_
[
worker_id
]
->
PopFront
(
&
io_block
));
}
RETURN_STATUS_UNEXPECTED
(
"Unexpected nullptr received in worker"
);
...
...
@@ -291,23 +266,16 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_bu
*
fetched_buffer
=
std
::
make_unique
<
DataBuffer
>
(
buffer_id
,
DataBuffer
::
kDeBFlagNone
);
std
::
unique_ptr
<
TensorQTable
>
tensor_table
=
std
::
make_unique
<
TensorQTable
>
();
for
(
int32_t
i
=
0
;
i
<
rows_per_buffer_
;
++
i
)
{
ShardTuple
tupled_buffer
;
mindrecord
::
TaskType
task_type
=
mindrecord
::
TaskType
::
kCommonTask
;
if
(
block_reader_
)
{
if
(
i
>=
block_buffer_
[
buffer_id
%
num_workers_
]
->
size
())
break
;
tupled_buffer
=
block_buffer_
[
buffer_id
%
num_workers_
]
->
at
(
i
);
}
else
{
int32_t
row_id
=
buffer_id
*
rows_per_buffer_
+
i
;
auto
rc
=
shard_reader_
->
GetNextById
(
row_id
,
worker_id
);
task_type
=
rc
.
first
;
tupled_buffer
=
rc
.
second
;
if
(
task_type
==
mindrecord
::
TaskType
::
kPaddedTask
)
{
TensorRow
tensor_row
;
RETURN_IF_NOT_OK
(
LoadTensorRow
(
&
tensor_row
,
{},
mindrecord
::
json
(),
task_type
));
tensor_table
->
push_back
(
std
::
move
(
tensor_row
));
}
if
(
tupled_buffer
.
empty
())
break
;
int32_t
row_id
=
buffer_id
*
rows_per_buffer_
+
i
;
auto
rc
=
shard_reader_
->
GetNextById
(
row_id
,
worker_id
);
auto
task_type
=
rc
.
first
;
auto
tupled_buffer
=
rc
.
second
;
if
(
task_type
==
mindrecord
::
TaskType
::
kPaddedTask
)
{
TensorRow
tensor_row
;
RETURN_IF_NOT_OK
(
LoadTensorRow
(
&
tensor_row
,
{},
mindrecord
::
json
(),
task_type
));
tensor_table
->
push_back
(
std
::
move
(
tensor_row
));
}
if
(
tupled_buffer
.
empty
())
break
;
if
(
task_type
==
mindrecord
::
TaskType
::
kCommonTask
)
{
for
(
const
auto
&
tupled_row
:
tupled_buffer
)
{
std
::
vector
<
uint8_t
>
columns_blob
=
std
::
get
<
0
>
(
tupled_row
);
...
...
@@ -396,21 +364,6 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint
return
Status
::
OK
();
}
Status
MindRecordOp
::
FetchBlockBuffer
(
const
int32_t
&
buffer_id
)
{
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
mtx_block_reader_
);
cv_reader_
.
wait
(
lck
,
[
buffer_id
,
this
]
{
return
buffer_id
<
buffer_water_mark_
+
num_workers_
;
});
}
for
(
int32_t
i
=
0
;
i
<
rows_per_buffer_
;
i
++
)
{
// Block reader does NOT care about argument
auto
rc
=
shard_reader_
->
GetNextById
(
i
,
i
);
ShardTuple
tuple_buffer
=
rc
.
second
;
if
(
tuple_buffer
.
empty
())
break
;
block_buffer_
[
buffer_id
%
num_workers_
]
->
push_back
(
std
::
move
(
tuple_buffer
));
}
return
Status
::
OK
();
}
// Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
...
...
@@ -423,7 +376,6 @@ Status MindRecordOp::operator()() {
while
(
true
)
{
// each iterator is 1 epoch
for
(
int32_t
i
=
0
;
i
<
buffers_needed_
;
++
i
)
{
if
(
block_reader_
)
RETURN_IF_NOT_OK
(
FetchBlockBuffer
(
i
));
std
::
vector
<
int64_t
>
keys
(
1
,
i
);
RETURN_IF_NOT_OK
(
io_blk_queues_
[
buf_cnt_
++
%
num_workers_
]
->
Add
(
std
::
make_unique
<
IOBlock
>
(
IOBlock
(
keys
,
IOBlock
::
kDeIoBlockNone
))));
...
...
@@ -455,12 +407,7 @@ Status MindRecordOp::operator()() {
Status
MindRecordOp
::
Reset
()
{
RETURN_IF_NOT_OK
(
ParallelOp
::
Reset
());
// Call our super class reset first.
if
(
block_reader_
)
{
shard_reader_
->
Reset
();
buffer_water_mark_
=
0
;
}
else
{
shard_reader_
->
ShuffleTask
();
}
shard_reader_
->
ShuffleTask
();
shard_reader_wait_post_
.
Set
();
return
Status
::
OK
();
...
...
@@ -473,7 +420,7 @@ Status MindRecordOp::LaunchThreadAndInitOp() {
RETURN_IF_NOT_OK
(
io_blk_queues_
.
Register
(
tree_
->
AllTasks
()));
RETURN_IF_NOT_OK
(
shard_reader_wait_post_
.
Register
(
tree_
->
AllTasks
()));
if
(
shard_reader_
->
Launch
(
!
block_reader_
)
==
MSRStatus
::
FAILED
)
{
if
(
shard_reader_
->
Launch
(
true
)
==
MSRStatus
::
FAILED
)
{
RETURN_STATUS_UNEXPECTED
(
"MindRecordOp launch failed."
);
}
// Launch main workers that load DataBuffers by reading all images
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h
浏览文件 @
4e2b1eec
...
...
@@ -94,11 +94,6 @@ class MindRecordOp : public ParallelOp {
return
*
this
;
}
Builder
&
SetBlockReader
()
{
build_block_reader_
=
true
;
return
*
this
;
}
Builder
&
SetLoadDataset
(
bool
load_dataset
)
{
build_load_dataset_
=
load_dataset
;
return
*
this
;
...
...
@@ -132,7 +127,6 @@ class MindRecordOp : public ParallelOp {
bool
build_load_dataset_
;
std
::
vector
<
std
::
string
>
build_columns_to_load_
;
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
build_operators_
;
bool
build_block_reader_
;
int64_t
build_num_padded_
;
py
::
handle
build_sample_
;
std
::
map
<
std
::
string
,
std
::
string
>
build_sample_bytes_
;
...
...
@@ -148,9 +142,8 @@ class MindRecordOp : public ParallelOp {
// @param operators - ShardOperators for Shuffle, Category, Sample
MindRecordOp
(
int32_t
num_mind_record_workers
,
int32_t
rows_per_buffer
,
std
::
vector
<
std
::
string
>
dataset_file
,
bool
load_dataset
,
int32_t
op_connector_queue_size
,
const
std
::
vector
<
std
::
string
>
&
columns_to_load
,
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
,
const
bool
&
block_reader
,
int64_t
num_padded_
,
const
mindrecord
::
json
&
sample_json
,
const
std
::
map
<
std
::
string
,
std
::
string
>
&
sample_bytes_
);
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
,
int64_t
num_padded_
,
const
mindrecord
::
json
&
sample_json
,
const
std
::
map
<
std
::
string
,
std
::
string
>
&
sample_bytes_
);
// Destructor
~
MindRecordOp
()
override
;
...
...
@@ -206,8 +199,6 @@ class MindRecordOp : public ParallelOp {
// Getter method
std
::
vector
<
std
::
string
>
columns_to_load
()
const
{
return
columns_to_load_
;
}
bool
block_reader
()
const
{
return
block_reader_
;
}
bool
load_dataset
()
const
{
return
load_dataset_
;
}
Status
Init
();
...
...
@@ -232,8 +223,6 @@ class MindRecordOp : public ParallelOp {
Status
LoadTensorRow
(
TensorRow
*
tensor_row
,
const
std
::
vector
<
uint8_t
>
&
columns_blob
,
const
mindrecord
::
json
&
columns_json
,
const
mindrecord
::
TaskType
task_type
);
Status
FetchBlockBuffer
(
const
int32_t
&
buffer_id
);
// Private function for computing the assignment of the column name map.
// @return - Status
Status
ComputeColMap
()
override
;
...
...
@@ -244,12 +233,10 @@ class MindRecordOp : public ParallelOp {
std
::
vector
<
std
::
string
>
columns_to_load_
;
// Columns to load from dataset
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
operators_
;
// ShardOperators to use
int32_t
num_mind_record_workers_
;
// number of workers to be spawned by ShardReader
bool
block_reader_
;
// block reader switch
int32_t
buffers_needed_
;
// Counter for the buffers that were fetched
int64_t
buf_cnt_
;
// Buffer counter
int32_t
num_rows_
;
// One more than the last row id in the range for this cache
std
::
atomic
<
int32_t
>
ended_worker_
;
std
::
atomic
<
int32_t
>
buffer_water_mark_
;
int64_t
num_padded_
;
mindrecord
::
json
sample_json_
;
...
...
@@ -263,12 +250,6 @@ class MindRecordOp : public ParallelOp {
WaitPost
shard_reader_wait_post_
;
QueueList
<
std
::
unique_ptr
<
IOBlock
>>
io_blk_queues_
;
// For block reader
std
::
mutex
mtx_block_reader_
;
std
::
condition_variable
cv_reader_
;
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
ShardTuple
>>>
block_buffer_
;
std
::
unordered_set
<
int32_t
>
block_set_
;
std
::
mutex
ended_worker_mutex_
;
};
}
// namespace dataset
...
...
mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
浏览文件 @
4e2b1eec
...
...
@@ -63,7 +63,6 @@ using ROW_GROUP_BRIEF =
using
TASK_RETURN_CONTENT
=
std
::
pair
<
MSRStatus
,
std
::
pair
<
TaskType
,
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>>>
;
const
int
kNumBatchInMap
=
1000
;
// iterator buffer size in row-reader mode
const
int
kNumPageInBuffer
=
16
;
// page buffer size in block-reader mode
class
ShardReader
{
public:
...
...
@@ -77,12 +76,10 @@ class ShardReader {
/// \param[in] n_consumer number of threads when reading
/// \param[in] selected_columns column list to be populated
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category
/// \param[in] block_reader block-reader mode if true, otherwise row-reader mode
/// \return MSRStatus the status of MSRStatus
MSRStatus
Open
(
const
std
::
vector
<
std
::
string
>
&
file_paths
,
bool
load_dataset
,
int
n_consumer
=
4
,
const
std
::
vector
<
std
::
string
>
&
selected_columns
=
{},
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
=
{},
const
bool
&
block_reader
=
false
,
const
int
num_padded
=
0
);
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
=
{},
const
int
num_padded
=
0
);
/// \brief open files and initialize reader, python API
/// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
...
...
@@ -189,10 +186,6 @@ class ShardReader {
std
::
pair
<
TaskType
,
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>>
GetNextById
(
const
int64_t
&
task_id
,
const
int32_t
&
consumer_id
);
/// \brief return a batch in block-reader mode, given that one is ready
/// \return a batch of images and image data
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
GetBlockNext
();
/// \brief return a batch, given that one is ready, python API
/// \return a batch of images and image data
std
::
vector
<
std
::
tuple
<
std
::
vector
<
std
::
vector
<
uint8_t
>>
,
pybind11
::
object
>>
GetNextPy
();
...
...
@@ -242,9 +235,6 @@ class ShardReader {
/// \brief populate one row by task list in row-reader mode
MSRStatus
ConsumerByRow
(
int
consumer_id
);
/// \brief populate one row by task list in block-reader mode
MSRStatus
ConsumerByBlock
(
int
consumer_id
);
/// \brief get offset address of images within page
std
::
vector
<
std
::
vector
<
uint64_t
>>
GetImageOffset
(
int
group_id
,
int
shard_id
,
const
std
::
pair
<
std
::
string
,
std
::
string
>
&
criteria
=
{
""
,
""
});
...
...
@@ -262,10 +252,6 @@ class ShardReader {
const
std
::
pair
<
std
::
string
,
std
::
string
>
&
criteria
=
{
""
,
""
});
/// \brief create task list in block-reader mode
MSRStatus
CreateTasksByBlock
(
const
std
::
vector
<
std
::
tuple
<
int
,
int
,
int
,
uint64_t
>>
&
row_group_summary
,
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
);
/// \brief create category-applied task list
MSRStatus
CreateTasksByCategory
(
const
std
::
vector
<
std
::
tuple
<
int
,
int
,
int
,
uint64_t
>>
&
row_group_summary
,
const
std
::
shared_ptr
<
ShardOperator
>
&
op
);
...
...
@@ -290,15 +276,10 @@ class ShardReader {
/// \brief read one row by one task
TASK_RETURN_CONTENT
ConsumerOneTask
(
int
task_id
,
uint32_t
consumer_id
);
/// \brief get one row from buffer in block-reader mode
std
::
shared_ptr
<
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>>
GetRowFromBuffer
(
int
bufId
,
int
rowId
);
/// \brief get labels from binary file
std
::
pair
<
MSRStatus
,
std
::
vector
<
json
>>
GetLabelsFromBinaryFile
(
int
shard_id
,
const
std
::
vector
<
std
::
string
>
&
columns
,
const
std
::
vector
<
std
::
vector
<
std
::
string
>>
&
label_offsets
);
MSRStatus
ReadBlob
(
const
int
&
shard_id
,
const
uint64_t
&
page_offset
,
const
int
&
page_length
,
const
int
&
buf_id
);
/// \brief get classes in one shard
void
GetClassesInShard
(
sqlite3
*
db
,
int
shard_id
,
const
std
::
string
sql
,
std
::
set
<
std
::
string
>
&
categories
);
...
...
@@ -349,16 +330,6 @@ class ShardReader {
// map of delivery
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>>>
delivery_map_
;
// Delivery/Iterator mode end
// Block reader mode begin
bool
block_reader_
;
// block-reader mode
int
row_id_
;
// row id in one page
int
num_blocks_
;
// number of pages
// raw data page
std
::
vector
<
std
::
shared_ptr
<
std
::
pair
<
std
::
vector
<
std
::
vector
<
uint64_t
>>
,
std
::
vector
<
json
>>>>
delivery_block_
;
std
::
unordered_set
<
int
>
delivery_block_set_
;
// set of delivered pages
std
::
vector
<
std
::
vector
<
uint8_t
>>
buf_
;
// page buffer
// Block reader mode end
};
}
// namespace mindrecord
}
// namespace mindspore
...
...
mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
浏览文件 @
4e2b1eec
...
...
@@ -43,9 +43,6 @@ ShardReader::ShardReader() {
page_size_
=
0
;
header_size_
=
0
;
num_rows_
=
0
;
row_id_
=
0
;
num_blocks_
=
0
;
block_reader_
=
false
;
num_padded_
=
0
;
}
...
...
@@ -855,8 +852,7 @@ MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths
MSRStatus
ShardReader
::
Open
(
const
std
::
vector
<
std
::
string
>
&
file_paths
,
bool
load_dataset
,
int
n_consumer
,
const
std
::
vector
<
std
::
string
>
&
selected_columns
,
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
,
const
bool
&
block_reader
,
int
num_padded
)
{
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
,
int
num_padded
)
{
// Open file and set header by ShardReader
auto
ret
=
Init
(
file_paths
,
load_dataset
);
if
(
SUCCESS
!=
ret
)
{
...
...
@@ -890,19 +886,8 @@ MSRStatus ShardReader::Open(const std::vector<std::string> &file_paths, bool loa
operators_
=
operators
;
if
(
block_reader
)
{
block_reader_
=
true
;
if
(
Open
()
==
FAILED
)
{
return
FAILED
;
}
delivery_block_
=
std
::
vector
<
std
::
shared_ptr
<
std
::
pair
<
std
::
vector
<
std
::
vector
<
uint64_t
>>
,
std
::
vector
<
json
>>>>
(
kNumPageInBuffer
,
std
::
shared_ptr
<
std
::
pair
<
std
::
vector
<
std
::
vector
<
uint64_t
>>
,
std
::
vector
<
json
>>>
{});
buf_
=
std
::
vector
<
std
::
vector
<
uint8_t
>>
(
kNumPageInBuffer
,
std
::
vector
<
uint8_t
>
(
page_size_
));
}
else
{
block_reader_
=
false
;
if
(
Open
(
n_consumer
)
==
FAILED
)
{
return
FAILED
;
}
if
(
Open
(
n_consumer
)
==
FAILED
)
{
return
FAILED
;
}
return
SUCCESS
;
}
...
...
@@ -960,29 +945,13 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) {
}
for
(
int
x
=
0
;
x
<
n_consumer_
;
++
x
)
{
if
(
block_reader_
)
{
thread_set_
[
x
]
=
std
::
thread
(
&
ShardReader
::
ConsumerByBlock
,
this
,
x
);
}
else
{
thread_set_
[
x
]
=
std
::
thread
(
&
ShardReader
::
ConsumerByRow
,
this
,
x
);
}
thread_set_
[
x
]
=
std
::
thread
(
&
ShardReader
::
ConsumerByRow
,
this
,
x
);
}
MS_LOG
(
INFO
)
<<
"Launch read thread successfully."
;
return
SUCCESS
;
}
MSRStatus
ShardReader
::
CreateTasksByBlock
(
const
std
::
vector
<
std
::
tuple
<
int
,
int
,
int
,
uint64_t
>>
&
row_group_summary
,
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
)
{
CheckIfColumnInIndex
(
selected_columns_
);
for
(
const
auto
&
rg
:
row_group_summary
)
{
auto
shard_id
=
std
::
get
<
0
>
(
rg
);
auto
group_id
=
std
::
get
<
1
>
(
rg
);
auto
n_Rows
=
std
::
get
<
3
>
(
rg
);
tasks_
.
InsertTask
(
TaskType
::
kCommonTask
,
shard_id
,
group_id
,
std
::
vector
<
uint64_t
>
{
n_Rows
},
json
{});
}
return
SUCCESS
;
}
MSRStatus
ShardReader
::
CreateTasksByCategory
(
const
std
::
vector
<
std
::
tuple
<
int
,
int
,
int
,
uint64_t
>>
&
row_group_summary
,
const
std
::
shared_ptr
<
ShardOperator
>
&
op
)
{
CheckIfColumnInIndex
(
selected_columns_
);
...
...
@@ -1070,47 +1039,39 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, i
MSRStatus
ShardReader
::
CreateTasks
(
const
std
::
vector
<
std
::
tuple
<
int
,
int
,
int
,
uint64_t
>>
&
row_group_summary
,
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
)
{
if
(
block_reader_
)
{
if
(
SUCCESS
!=
CreateTasksByBlock
(
row_group_summary
,
operators
))
{
int
category_operator
=
-
1
;
for
(
uint32_t
i
=
0
;
i
<
operators
.
size
();
++
i
)
{
const
auto
&
op
=
operators
[
i
];
if
(
std
::
dynamic_pointer_cast
<
ShardCategory
>
(
op
))
{
category_operator
=
static_cast
<
int
>
(
i
);
break
;
}
}
if
(
-
1
==
category_operator
)
{
if
(
SUCCESS
!=
CreateTasksByRow
(
row_group_summary
,
operators
))
{
return
FAILED
;
}
}
else
{
int
category_operator
=
-
1
;
for
(
uint32_t
i
=
0
;
i
<
operators
.
size
();
++
i
)
{
const
auto
&
op
=
operators
[
i
];
if
(
std
::
dynamic_pointer_cast
<
ShardCategory
>
(
op
))
{
category_operator
=
static_cast
<
int
>
(
i
);
break
;
if
(
num_padded_
>
0
)
{
for
(
int
i
=
0
;
i
<
num_padded_
;
++
i
)
{
tasks_
.
InsertTask
(
TaskType
::
kPaddedTask
,
0
,
0
,
{},
json
());
}
}
if
(
-
1
==
category_operator
)
{
if
(
SUCCESS
!=
CreateTasksByRow
(
row_group_summary
,
operators
))
{
return
FAILED
;
}
if
(
num_padded_
>
0
)
{
for
(
int
i
=
0
;
i
<
num_padded_
;
++
i
)
{
tasks_
.
InsertTask
(
TaskType
::
kPaddedTask
,
0
,
0
,
{},
json
());
}
}
}
else
{
if
(
SUCCESS
!=
CreateTasksByCategory
(
row_group_summary
,
operators
[
category_operator
]))
{
return
FAILED
;
}
}
else
{
if
(
SUCCESS
!=
CreateTasksByCategory
(
row_group_summary
,
operators
[
category_operator
]))
{
return
FAILED
;
}
}
for
(
uint32_t
operator_no
=
0
;
operator_no
<
operators
.
size
();
operator_no
++
)
{
const
auto
&
op
=
operators
[
operator_no
];
if
(
std
::
dynamic_pointer_cast
<
ShardCategory
>
(
op
))
continue
;
if
(
block_reader_
&&
std
::
dynamic_pointer_cast
<
ShardShuffle
>
(
op
))
continue
;
if
(
SUCCESS
!=
(
*
op
)(
tasks_
))
{
return
FAILED
;
}
}
if
(
tasks_
.
permutation_
.
empty
())
tasks_
.
MakePerm
();
num_rows_
=
block_reader_
?
tasks_
.
SizeOfRows
()
:
tasks_
.
Size
();
num_blocks_
=
block_reader_
?
tasks_
.
Size
()
:
0
;
num_rows_
=
tasks_
.
Size
();
MS_LOG
(
INFO
)
<<
"Total rows is "
<<
num_rows_
;
return
SUCCESS
;
}
...
...
@@ -1207,140 +1168,10 @@ MSRStatus ShardReader::ConsumerByRow(int consumer_id) {
}
}
MSRStatus
ShardReader
::
ReadBlob
(
const
int
&
shard_id
,
const
uint64_t
&
page_offset
,
const
int
&
page_length
,
const
int
&
buf_id
)
{
auto
&
io_seekg
=
file_streams_
[
shard_id
]
->
seekg
(
page_offset
,
std
::
ios
::
beg
);
if
(
!
io_seekg
.
good
()
||
io_seekg
.
fail
()
||
io_seekg
.
bad
())
{
MS_LOG
(
ERROR
)
<<
"File seekg failed"
;
file_streams_
[
shard_id
]
->
close
();
return
FAILED
;
}
auto
&
io_read
=
file_streams_
[
shard_id
]
->
read
(
reinterpret_cast
<
char
*>
(
&
buf_
[
buf_id
][
0
]),
page_length
);
if
(
!
io_read
.
good
()
||
io_read
.
fail
()
||
io_read
.
bad
())
{
MS_LOG
(
ERROR
)
<<
"File read failed"
;
file_streams_
[
shard_id
]
->
close
();
return
FAILED
;
}
return
SUCCESS
;
}
MSRStatus
ShardReader
::
ConsumerByBlock
(
int
consumer_id
)
{
// Set thread name
#if !defined(_WIN32) && !defined(_WIN64)
auto
thread_id
=
kThreadName
+
std
::
to_string
(
consumer_id
);
prctl
(
PR_SET_NAME
,
common
::
SafeCStr
(
thread_id
),
0
,
0
,
0
);
#endif
// Loop forever
for
(;;)
{
int
task_id
=
0
;
// Get next task ID
task_id
=
task_id_
++
;
// All tasks are done, either quit or repeat again
if
(
task_id
>=
num_blocks_
)
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
mtx_delivery_
);
cv_delivery_
.
wait
(
lck
,
[
this
]
{
return
interrupt_
||
task_id_
<
num_blocks_
;
});
if
(
interrupt_
)
{
return
SUCCESS
;
}
continue
;
}
// Pick up task from task list
auto
task
=
tasks_
.
GetTaskByID
(
tasks_
.
permutation_
[
task_id
]);
auto
shard_id
=
std
::
get
<
0
>
(
std
::
get
<
1
>
(
task
));
auto
group_id
=
std
::
get
<
1
>
(
std
::
get
<
1
>
(
task
));
auto
row_group_brief
=
ReadRowGroupBrief
(
group_id
,
shard_id
,
selected_columns_
);
if
(
SUCCESS
!=
std
::
get
<
0
>
(
row_group_brief
))
{
return
FAILED
;
}
auto
page_length
=
std
::
get
<
2
>
(
row_group_brief
);
auto
page_offset
=
std
::
get
<
3
>
(
row_group_brief
);
MS_LOG
(
DEBUG
)
<<
"Block task "
<<
task_id
<<
tasks_
.
permutation_
[
task_id
]
<<
", shard "
<<
shard_id
<<
", group "
<<
group_id
<<
", page length "
<<
page_length
<<
", page offset "
<<
page_offset
;
// Deliver block data to output map
auto
offset_and_labels
=
std
::
make_pair
(
std
::
get
<
4
>
(
row_group_brief
),
std
::
get
<
5
>
(
row_group_brief
));
int
deliver_id
=
deliver_id_
;
// Hanging if maximum map size exceeded otherwise, set batch data in buffer
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
mtx_delivery_
);
cv_delivery_
.
wait
(
lck
,
[
task_id
,
this
]
{
return
interrupt_
||
task_id
<
deliver_id_
+
kNumPageInBuffer
;
});
if
(
interrupt_
)
{
return
SUCCESS
;
}
}
auto
buf_id
=
task_id
%
kNumPageInBuffer
;
delivery_block_
[
buf_id
]
=
std
::
make_shared
<
std
::
pair
<
std
::
vector
<
std
::
vector
<
uint64_t
>>
,
std
::
vector
<
json
>>>
(
offset_and_labels
);
// Read blob
if
(
ReadBlob
(
shard_id
,
page_offset
,
page_length
,
buf_id
)
!=
SUCCESS
)
{
return
FAILED
;
}
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
mtx_delivery_
);
delivery_block_set_
.
insert
(
task_id
);
}
cv_iterator_
.
notify_one
();
}
}
std
::
shared_ptr
<
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>>
ShardReader
::
GetRowFromBuffer
(
int
buf_id
,
int
rowId
)
{
auto
&
blob_page
=
buf_
[
buf_id
];
auto
&
offsets
=
(
*
delivery_block_
[
buf_id
]).
first
;
auto
&
labels
=
(
*
delivery_block_
[
buf_id
]).
second
;
auto
&
addr_start
=
offsets
[
rowId
][
0
];
auto
&
addr_end
=
offsets
[
rowId
][
1
];
std
::
vector
<
uint8_t
>
images
(
blob_page
.
begin
()
+
addr_start
,
blob_page
.
begin
()
+
addr_end
);
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
batch
;
batch
.
emplace_back
(
std
::
move
(
images
),
std
::
move
(
labels
[
rowId
]));
return
std
::
make_shared
<
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>>
(
std
::
move
(
batch
));
}
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
ShardReader
::
GetBlockNext
()
{
if
(
deliver_id_
>=
num_blocks_
)
{
return
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
();
}
if
(
row_id_
==
0
)
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
mtx_delivery_
);
cv_iterator_
.
wait
(
lck
,
[
this
]
{
return
interrupt_
||
(
delivery_block_set_
.
count
(
deliver_id_
)
>
0
);
});
if
(
interrupt_
)
{
return
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
();
}
}
auto
buf_id
=
deliver_id_
%
kNumPageInBuffer
;
auto
res
=
GetRowFromBuffer
(
buf_id
,
row_id_
);
row_id_
++
;
if
(
row_id_
==
(
*
delivery_block_
[
buf_id
]).
first
.
size
())
{
row_id_
=
0
;
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
mtx_delivery_
);
delivery_block_set_
.
erase
(
deliver_id_
++
);
}
cv_delivery_
.
notify_all
();
}
return
*
res
;
}
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
ShardReader
::
GetNext
()
{
if
(
interrupt_
)
{
return
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
();
}
if
(
block_reader_
)
return
GetBlockNext
();
if
(
deliver_id_
>=
static_cast
<
int
>
(
tasks_
.
Size
()))
{
return
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
();
}
...
...
@@ -1366,9 +1197,6 @@ std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardRe
if
(
interrupt_
)
{
return
std
::
make_pair
(
TaskType
::
kCommonTask
,
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
());
}
if
(
block_reader_
)
{
return
std
::
make_pair
(
TaskType
::
kCommonTask
,
GetBlockNext
());
}
const
auto
&
ret
=
ConsumerOneTask
(
task_id
,
consumer_id
);
if
(
SUCCESS
!=
ret
.
first
)
{
return
std
::
make_pair
(
TaskType
::
kCommonTask
,
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
());
...
...
@@ -1423,7 +1251,6 @@ void ShardReader::Reset() {
}
void
ShardReader
::
ShuffleTask
()
{
if
(
block_reader_
)
return
;
// exist shuffle and distributed sampler in ops, skip shuffle
bool
has_sharding
=
false
;
for
(
const
auto
&
op
:
operators_
)
{
...
...
mindspore/dataset/core/validator_helpers.py
浏览文件 @
4e2b1eec
...
...
@@ -300,7 +300,6 @@ def check_padding_options(param_dict):
"""
columns_list
=
param_dict
.
get
(
'columns_list'
)
block_reader
=
param_dict
.
get
(
'block_reader'
)
padded_sample
,
num_padded
=
param_dict
.
get
(
'padded_sample'
),
param_dict
.
get
(
'num_padded'
)
if
padded_sample
is
not
None
:
if
num_padded
is
None
:
...
...
@@ -312,9 +311,6 @@ def check_padding_options(param_dict):
for
column
in
columns_list
:
if
column
not
in
padded_sample
:
raise
ValueError
(
"padded_sample cannot match columns_list."
)
if
block_reader
:
raise
RuntimeError
(
"block_reader and padded_sample cannot be specified at the same time."
)
if
padded_sample
is
None
and
num_padded
is
not
None
:
raise
RuntimeError
(
"num_padded is specified but padded_sample is not."
)
...
...
mindspore/dataset/engine/datasets.py
浏览文件 @
4e2b1eec
...
...
@@ -2795,7 +2795,6 @@ class MindDataset(MappableDataset):
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument should be specified only when num_shards is also specified.
block_reader (bool, optional): Whether read data by block mode (default=False).
sampler (Sampler, optional): Object used to choose samples from the
dataset (default=None, sampler is exclusive
with shuffle and block_reader). Support list: SubsetRandomSampler,
...
...
@@ -2810,13 +2809,12 @@ class MindDataset(MappableDataset):
Raises:
ValueError: If num_shards is specified but shard_id is None.
ValueError: If shard_id is specified but num_shards is None.
ValueError: If block reader is true but partition is specified.
"""
@
check_minddataset
def
__init__
(
self
,
dataset_file
,
columns_list
=
None
,
num_parallel_workers
=
None
,
shuffle
=
None
,
num_shards
=
None
,
shard_id
=
None
,
block_reader
=
False
,
sampler
=
None
,
padded_sample
=
None
,
sampler
=
None
,
padded_sample
=
None
,
num_padded
=
None
,
num_samples
=
None
):
super
().
__init__
(
num_parallel_workers
)
if
isinstance
(
dataset_file
,
list
):
...
...
@@ -2828,14 +2826,7 @@ class MindDataset(MappableDataset):
self
.
shuffle_option
=
shuffle
self
.
num_shards
=
num_shards
self
.
shard_id
=
shard_id
if
block_reader
is
True
and
num_shards
is
not
None
:
raise
ValueError
(
"block_reader not allowed true when use partitions"
)
if
block_reader
is
True
and
shuffle
is
True
:
raise
ValueError
(
"block_reader not allowed true when use shuffle"
)
if
block_reader
is
True
:
if
shuffle
is
False
:
logger
.
warning
(
"WARN: global shuffle is not used."
)
if
sampler
is
not
None
:
...
...
@@ -2846,15 +2837,9 @@ class MindDataset(MappableDataset):
self
.
sampler
=
_select_sampler
(
num_samples
,
sampler
,
shuffle
,
num_shards
,
shard_id
)
self
.
num_samples
=
num_samples
# sampler exclusive
if
block_reader
is
True
and
sampler
is
not
None
:
raise
ValueError
(
"block_reader not allowed true when use sampler"
)
if
num_padded
is
None
:
num_padded
=
0
self
.
block_reader
=
block_reader
self
.
padded_sample
=
padded_sample
self
.
num_padded
=
num_padded
...
...
@@ -2873,7 +2858,6 @@ class MindDataset(MappableDataset):
args
[
"columns_list"
]
=
self
.
columns_list
args
[
"shuffle_option"
]
=
self
.
shuffle_option
args
[
"num_samples"
]
=
self
.
num_samples
args
[
"block_reader"
]
=
self
.
block_reader
args
[
"num_padded"
]
=
self
.
num_padded
args
[
"padded_sample"
]
=
padded_sample
args
[
"sampler"
]
=
self
.
sampler
...
...
mindspore/dataset/engine/serializer_deserializer.py
浏览文件 @
4e2b1eec
...
...
@@ -279,7 +279,7 @@ def create_node(node):
sampler
=
construct_sampler
(
node
.
get
(
'sampler'
))
pyobj
=
pyclass
(
node
[
'dataset_file'
],
node
.
get
(
'columns_list'
),
node
.
get
(
'num_parallel_workers'
),
node
.
get
(
'seed'
),
node
.
get
(
'num_shards'
),
node
.
get
(
'shard_id'
),
node
.
get
(
'block_reader'
),
sampler
)
node
.
get
(
'shard_id'
),
sampler
)
elif
dataset_op
==
'TFRecordDataset'
:
pyobj
=
pyclass
(
node
[
'dataset_files'
],
node
.
get
(
'schema'
),
node
.
get
(
'column_list'
),
...
...
mindspore/dataset/engine/validators.py
浏览文件 @
4e2b1eec
...
...
@@ -273,7 +273,6 @@ def check_minddataset(method):
nreq_param_int
=
[
'num_samples'
,
'num_parallel_workers'
,
'seed'
,
'num_shards'
,
'shard_id'
,
'num_padded'
]
nreq_param_list
=
[
'columns_list'
]
nreq_param_bool
=
[
'block_reader'
]
nreq_param_dict
=
[
'padded_sample'
]
dataset_file
=
param_dict
.
get
(
'dataset_file'
)
...
...
@@ -287,7 +286,6 @@ def check_minddataset(method):
validate_dataset_param_value
(
nreq_param_int
,
param_dict
,
int
)
validate_dataset_param_value
(
nreq_param_list
,
param_dict
,
list
)
validate_dataset_param_value
(
nreq_param_bool
,
param_dict
,
bool
)
validate_dataset_param_value
(
nreq_param_dict
,
param_dict
,
dict
)
check_sampler_shuffle_shard_options
(
param_dict
)
...
...
tests/ut/cpp/dataset/mind_record_op_test.cc
浏览文件 @
4e2b1eec
...
...
@@ -435,7 +435,6 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBlockReaderRepeat) {
.
SetLoadDataset
(
true
)
.
SetRowsPerBuffer
(
3
)
.
SetNumMindRecordWorkers
(
4
)
.
SetBlockReader
()
.
SetColumnsToLoad
(
column_list
);
rc
=
builder
.
Build
(
&
my_mindrecord_op
);
ASSERT_TRUE
(
rc
.
IsOk
());
...
...
tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
浏览文件 @
4e2b1eec
...
...
@@ -94,31 +94,6 @@ TEST_F(TestShardReader, TestShardReaderSample) {
dataset
.
Close
();
}
// Reads the imagenet shard in block-reader mode and logs every key/value
// pair of each returned row's json payload, then shuts the reader down.
TEST_F(TestShardReader, TestShardReaderBlock) {
  MS_LOG(INFO) << FormatInfo("Test read imageNet with block way");
  const std::string shard_file = "./imagenet.shard01";
  const std::vector<std::string> columns{"label"};
  std::vector<std::shared_ptr<ShardOperator>> operators;
  operators.push_back(std::make_shared<ShardSample>(3));
  ShardReader reader;
  // Last argument enables block-reader mode; 4 is the consumer thread count.
  reader.Open({shard_file}, true, 4, columns, operators, true);
  reader.Launch();
  // Drain blocks until the reader returns an empty batch.
  for (auto block = reader.GetBlockNext(); !block.empty(); block = reader.GetBlockNext()) {
    for (auto &row : block) {
      for (auto &field : std::get<1>(row).items()) {
        MS_LOG(INFO) << "key: " << field.key() << ", value: " << field.value().dump();
      }
    }
  }
  reader.Finish();
  reader.Close();
}
TEST_F
(
TestShardReader
,
TestShardReaderEasy
)
{
MS_LOG
(
INFO
)
<<
FormatInfo
(
"Test read imageNet"
);
std
::
string
file_name
=
"./imagenet.shard01"
;
...
...
tests/ut/python/dataset/test_minddataset.py
浏览文件 @
4e2b1eec
...
...
@@ -591,49 +591,6 @@ def test_cv_minddataset_issue_888(add_and_remove_cv_file):
assert
num_iter
==
18
def test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file):
    """Tutorial for cv MindDataset with block_reader enabled.

    Opens the mindrecord file produced by the fixture with block_reader=True,
    repeats the dataset twice, and verifies that every sample is iterated
    exactly once per epoch (the test expects 10 samples in the file).
    """
    columns_list = ["data", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              block_reader=True)
    assert data_set.get_dataset_size() == 10
    repeat_num = 2
    data_set = data_set.repeat(repeat_num)
    num_iter = 0
    for item in data_set.create_dict_iterator():
        # Fixed typo in log message: "tow" -> "two".
        logger.info("-------------- block reader repeat two {} -----------------".format(num_iter))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
        num_iter += 1
    # 10 samples x 2 repeats.
    assert num_iter == 20
def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_remove_cv_file):
    """Tutorial for cv MindDataset with block_reader and a non-index column.

    Same as the basic block-reader tutorial but also loads the "id" column,
    which is not part of the mindrecord index. Repeats twice and verifies
    the total iteration count (the test expects 10 samples in the file).
    """
    columns_list = ["id", "data", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              shuffle=False, block_reader=True)
    assert data_set.get_dataset_size() == 10
    repeat_num = 2
    data_set = data_set.repeat(repeat_num)
    num_iter = 0
    for item in data_set.create_dict_iterator():
        # Fixed typo in log message: "tow" -> "two".
        logger.info("-------------- block reader repeat two {} -----------------".format(num_iter))
        logger.info("-------------- item[id]: {} ----------------------------".format(item["id"]))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
        num_iter += 1
    # 10 samples x 2 repeats.
    assert num_iter == 20
def
test_cv_minddataset_reader_file_list
(
add_and_remove_cv_file
):
"""tutorial for cv minderdataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录