magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 6369cf27
Authored on Apr 21, 2020 by mindspore-ci-bot; committed via Gitee on Apr 21, 2020
!406 added first row crc check for when reading tfrecord files
Merge pull request !406 from Peilin/first-row-crc-check
Parents: 98fbd30a, 9bc2134c

Showing 8 changed files with 127 additions and 11 deletions (+127, -11)
  mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc   +51  -6
  mindspore/dataset/engine/datasets.py                               +13  -4
  tests/ut/cpp/dataset/tfReader_op_test.cc                           +34  -0
  tests/ut/data/dataset/testTFBert5Rows/5TFDatas.data                 +0  -0
  tests/ut/data/dataset/testTFBert5Rows1/5TFDatas.data                +0  -0
  tests/ut/data/dataset/testTFBert5Rows2/5TFDatas.data                +0  -0
  tests/ut/data/dataset/testTFTestAllTypes/invalidFile.txt            +1  -0
  tests/ut/python/dataset/test_tfreader_op.py                        +28  -1
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc

@@ -42,6 +42,7 @@
 #include "dataset/util/status.h"
 #include "dataset/util/task_manager.h"
 #include "dataset/util/wait_post.h"
+#include "utils/system/crc32c.h"

 namespace mindspore {
 namespace dataset {

@@ -56,15 +57,58 @@ TFReaderOp::Builder::Builder()
   builder_data_schema_ = std::make_unique<DataSchema>();
 }

+bool ValidateFirstRowCrc(const std::string &filename) {
+  std::ifstream reader;
+  reader.open(filename);
+  if (!reader) {
+    return false;
+  }
+
+  // read data
+  int64_t record_length = 0;
+  (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));
+
+  // read crc from file
+  uint32_t masked_crc = 0;
+  (void)reader.read(reinterpret_cast<char *>(&masked_crc), static_cast<std::streamsize>(sizeof(uint32_t)));
+
+  // generate crc from data
+  uint32_t generated_crc =
+    system::Crc32c::GetMaskCrc32cValue(reinterpret_cast<char *>(&record_length), sizeof(int64_t));
+
+  return masked_crc == generated_crc;
+}
+
 Status TFReaderOp::Builder::ValidateInputs() const {
   std::string err_msg;
-  err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is smaller or equal to 0\n" : "";
-  if (!builder_equal_rows_per_shard_) {
-    err_msg += builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_) ? "No enough tf_file files provided\n" : "";
+  if (builder_num_workers_ <= 0) {
+    err_msg += "Number of parallel workers is smaller or equal to 0\n";
   }
+
+  if (!builder_equal_rows_per_shard_ &&
+      builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_)) {
+    err_msg += "Not enough tfrecord files provided\n";
+  }
+
+  if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) {
+    err_msg += "Wrong sharding configs\n";
+  }
-  err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : "";
+
+  std::vector<std::string> invalid_files(builder_dataset_files_list_.size());
+  auto it = std::copy_if(builder_dataset_files_list_.begin(), builder_dataset_files_list_.end(), invalid_files.begin(),
+                         [](const std::string &filename) { return !ValidateFirstRowCrc(filename); });
+  invalid_files.resize(std::distance(invalid_files.begin(), it));
+
+  if (!invalid_files.empty()) {
+    err_msg += "The following files either cannot be opened, or are not valid tfrecord files:\n";
+    std::string accumulated_filenames =
+      std::accumulate(invalid_files.begin(), invalid_files.end(), std::string(""),
+                      [](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; });
+    err_msg += accumulated_filenames;
+  }
+
   return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
 }

@@ -523,6 +567,7 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off
    RETURN_IF_NOT_OK(LoadExample(&tf_file, &new_tensor_table, rows_read));
    rows_read++;
  }

  // ignore crc footer
  (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));

  rows_total++;
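For context on what ValidateFirstRowCrc is checking: in the TFRecord on-disk format, every record begins with an 8-byte little-endian length followed by a 4-byte masked CRC-32C of those length bytes, which is exactly the 12-byte header the new function reads back and verifies. The Python sketch below mirrors the same check outside of MindSpore. It is illustrative only: the `crc32c` package and the 0xA282EAD8 masking constant come from the publicly documented TFRecord format, not from this commit, and the helper names are made up.

import struct

import crc32c  # assumed third-party CRC-32C package (pip install crc32c)


def masked_crc32c(data: bytes) -> int:
    # TFRecord-style mask: rotate the CRC right by 15 bits and add a constant.
    crc = crc32c.crc32c(data) & 0xFFFFFFFF
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def validate_first_row_crc(filename: str) -> bool:
    # Read the 8-byte record length and the 4-byte masked CRC stored after it,
    # then recompute the masked CRC over the length bytes and compare.
    try:
        with open(filename, "rb") as f:
            header = f.read(12)
    except OSError:
        return False
    if len(header) < 12:
        return False
    length_bytes = header[:8]
    stored_crc = struct.unpack("<I", header[8:])[0]
    return masked_crc32c(length_bytes) == stored_crc

A plain text file such as the invalidFile.txt fixture added in this change would fail this comparison immediately, which is how the builder can reject it before any rows are parsed.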
mindspore/dataset/engine/datasets.py

@@ -926,13 +926,22 @@ class SourceDataset(Dataset):
             List, files.
         """

         def flat(lists):
             return list(np.array(lists).flatten())

         if not isinstance(patterns, list):
             patterns = [patterns]

-        file_list = flat([glob.glob(file, recursive=True) for file in patterns])
+        file_list = []
+        unmatched_patterns = []
+        for pattern in patterns:
+            matches = [match for match in glob.glob(pattern, recursive=True) if os.path.isfile(match)]
+            if matches:
+                file_list.extend(matches)
+            else:
+                unmatched_patterns.append(pattern)
+        if unmatched_patterns:
+            raise ValueError("The following patterns did not match any files: ", unmatched_patterns)
         if file_list:  # not empty
             return file_list
         raise ValueError("The list of path names matching the patterns is empty.")
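A usage note on the datasets.py change: a pattern (or literal path) that matches no regular file is now reported up front with a ValueError instead of being silently dropped from the file list. A minimal illustration, assuming a MindSpore installation; the glob pattern and path below are hypothetical and intentionally match nothing:

import mindspore.dataset as ds

try:
    # "/tmp/nonexistent-*.data" is a hypothetical pattern that matches no files.
    data = ds.TFRecordDataset(["/tmp/nonexistent-*.data"])
except ValueError as err:
    # Raised at construction time, listing the unmatched patterns.
    print(err)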
tests/ut/cpp/dataset/tfReader_op_test.cc

@@ -697,3 +697,37 @@ TEST_F(MindDataTestTFReaderOp, TestTotalRowsBasic) {
   TFReaderOp::CountTotalRows(&total_rows, filenames, 729, true);
   ASSERT_EQ(total_rows, 60);
 }
+
+TEST_F(MindDataTestTFReaderOp, TestTFReaderInvalidFiles) {
+  // Start with an empty execution tree
+  auto my_tree = std::make_shared<ExecutionTree>();
+
+  std::string valid_file = datasets_root_path_ + "/testTFTestAllTypes/test.data";
+  std::string schema_file = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json";
+  std::string invalid_file = datasets_root_path_ + "/testTFTestAllTypes/invalidFile.txt";
+  std::string nonexistent_file = "this/file/doesnt/exist";
+
+  std::shared_ptr<TFReaderOp> my_tfreader_op;
+  TFReaderOp::Builder builder;
+  builder.SetDatasetFilesList({invalid_file, valid_file, schema_file})
+    .SetRowsPerBuffer(16)
+    .SetNumWorkers(16);
+
+  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
+  schema->LoadSchemaFile(schema_file, {});
+  builder.SetDataSchema(std::move(schema));
+
+  Status rc = builder.Build(&my_tfreader_op);
+  ASSERT_TRUE(!rc.IsOk());
+
+  builder.SetDatasetFilesList({invalid_file, valid_file, schema_file, nonexistent_file})
+    .SetRowsPerBuffer(16)
+    .SetNumWorkers(16);
+
+  schema = std::make_unique<DataSchema>();
+  schema->LoadSchemaFile(schema_file, {});
+  builder.SetDataSchema(std::move(schema));
+
+  rc = builder.Build(&my_tfreader_op);
+  ASSERT_TRUE(!rc.IsOk());
+}
tests/ut/data/dataset/testTFBert5Rows/5TFDatas.data
Cannot preview this file type.

tests/ut/data/dataset/testTFBert5Rows1/5TFDatas.data
Cannot preview this file type.

tests/ut/data/dataset/testTFBert5Rows2/5TFDatas.data
Cannot preview this file type.
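The three 5TFDatas.data fixtures above are binary TFRecord files, so no textual diff is shown; presumably they were regenerated so that their first record carries the masked length CRC the reader now validates. For reference, a record that would pass the new check can be produced with a writer-side sketch like the following (illustrative only; it reuses the assumed crc32c package and masking constant from the earlier sketch and is not a MindSpore or TensorFlow API):

import struct

import crc32c  # assumed third-party CRC-32C package, as in the earlier sketch


def masked_crc32c(data: bytes) -> int:
    crc = crc32c.crc32c(data) & 0xFFFFFFFF
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_record(out, payload: bytes) -> None:
    # One TFRecord-style record: length, masked CRC of the length,
    # payload, masked CRC of the payload.
    length_bytes = struct.pack("<Q", len(payload))
    out.write(length_bytes)
    out.write(struct.pack("<I", masked_crc32c(length_bytes)))  # the "first row" CRC the reader checks
    out.write(payload)
    out.write(struct.pack("<I", masked_crc32c(payload)))       # data CRC footer, skipped by LoadFile


# The payload would normally be a serialized Example proto; raw bytes keep the sketch minimal.
with open("sample.tfrecord", "wb") as f:
    write_record(f, b"example-bytes")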
tests/ut/data/dataset/testTFTestAllTypes/invalidFile.txt
new file mode 100644

+this is just a text file, not a valid tfrecord file.
tests/ut/python/dataset/test_tfreader_op.py

@@ -32,7 +32,7 @@ def test_case_tf_shape():
     ds1 = ds.TFRecordDataset(FILES, schema_file)
     ds1 = ds1.batch(2)
     for data in ds1.create_dict_iterator():
-        print(data)
+        logger.info(data)
     output_shape = ds1.output_shapes()
     assert (len(output_shape[-1]) == 1)

@@ -203,6 +203,32 @@ def test_tf_record_schema_columns_list():
             a = row["col_sint32"]
     assert "col_sint32" in str(info.value)

+
+def test_case_invalid_files():
+    valid_file = "../data/dataset/testTFTestAllTypes/test.data"
+    invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt"
+    files = [invalid_file, valid_file, SCHEMA_FILE]
+
+    data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
+
+    with pytest.raises(RuntimeError) as info:
+        row = data.create_dict_iterator().get_next()
+    assert "cannot be opened" in str(info.value)
+    assert "not valid tfrecord files" in str(info.value)
+    assert valid_file not in str(info.value)
+    assert invalid_file in str(info.value)
+    assert SCHEMA_FILE in str(info.value)
+
+    nonexistent_file = "this/file/does/not/exist"
+    files = [invalid_file, valid_file, SCHEMA_FILE, nonexistent_file]
+
+    with pytest.raises(ValueError) as info:
+        data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
+    assert "did not match any files" in str(info.value)
+    assert valid_file not in str(info.value)
+    assert invalid_file not in str(info.value)
+    assert SCHEMA_FILE not in str(info.value)
+    assert nonexistent_file in str(info.value)
+
+
 if __name__ == '__main__':
     test_case_tf_shape()
     test_case_tf_file()

@@ -212,3 +238,4 @@ if __name__ == '__main__':
     test_tf_record_schema()
     test_tf_record_shuffle()
     test_tf_shard_equal_rows()
+    test_case_invalid_files()