Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
db80f4ff
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
db80f4ff
编写于
4月 20, 2020
作者:
Q
qianlong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
The num_samples and numRows in schema for TFRecordDataset are conflict
上级
46acf238
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
119 addition
and
4 deletion
+119
-4
mindspore/ccsrc/dataset/engine/datasetops/source/storage_client.cc
.../ccsrc/dataset/engine/datasetops/source/storage_client.cc
+5
-1
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc
...re/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc
+3
-0
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+9
-3
tests/ut/data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json
...t/data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json
+45
-0
tests/ut/data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json
...ta/dataset/test_tf_file_3_images/datasetNoRowsSchema.json
+15
-0
tests/ut/python/dataset/test_storage.py
tests/ut/python/dataset/test_storage.py
+12
-0
tests/ut/python/dataset/test_tfreader_op.py
tests/ut/python/dataset/test_tfreader_op.py
+30
-0
未找到文件。
mindspore/ccsrc/dataset/engine/datasetops/source/storage_client.cc
浏览文件 @
db80f4ff
...
@@ -162,7 +162,11 @@ Status StorageClient::numRowsFromFile(uint32_t &num_rows) const {
...
@@ -162,7 +162,11 @@ Status StorageClient::numRowsFromFile(uint32_t &num_rows) const {
std
::
ifstream
in
(
schemaFile
);
std
::
ifstream
in
(
schemaFile
);
nlohmann
::
json
js
;
nlohmann
::
json
js
;
in
>>
js
;
in
>>
js
;
num_rows
=
js
.
value
(
"numRows"
,
0
);
if
(
js
.
find
(
"numRows"
)
==
js
.
end
())
{
num_rows
=
MAX_INTEGER_INT32
;
}
else
{
num_rows
=
js
.
value
(
"numRows"
,
0
);
}
if
(
num_rows
==
0
)
{
if
(
num_rows
==
0
)
{
std
::
string
err_msg
=
std
::
string
err_msg
=
"Storage client has not properly done dataset "
"Storage client has not properly done dataset "
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc
浏览文件 @
db80f4ff
...
@@ -163,6 +163,9 @@ Status TFReaderOp::Init() {
...
@@ -163,6 +163,9 @@ Status TFReaderOp::Init() {
if
(
total_rows_
==
0
)
{
if
(
total_rows_
==
0
)
{
total_rows_
=
data_schema_
->
num_rows
();
total_rows_
=
data_schema_
->
num_rows
();
}
}
if
(
total_rows_
<
0
)
{
RETURN_STATUS_UNEXPECTED
(
"The num_sample or numRows for TFRecordDataset should be greater than 0"
);
}
// Build the index with our files such that each file corresponds to a key id.
// Build the index with our files such that each file corresponds to a key id.
RETURN_IF_NOT_OK
(
filename_index_
->
insert
(
dataset_files_list_
));
RETURN_IF_NOT_OK
(
filename_index_
->
insert
(
dataset_files_list_
));
...
...
mindspore/dataset/engine/datasets.py
浏览文件 @
db80f4ff
...
@@ -1455,7 +1455,7 @@ class StorageDataset(SourceDataset):
...
@@ -1455,7 +1455,7 @@ class StorageDataset(SourceDataset):
Args:
Args:
dataset_files (list[str]): List of files to be read.
dataset_files (list[str]): List of files to be read.
schema (str): Path to the json schema file.
schema (str): Path to the json schema file.
If numRows(parsed from schema) is not exist, read the full dataset.
distribution (str, optional): Path of distribution config file (default="").
distribution (str, optional): Path of distribution config file (default="").
columns_list (list[str], optional): List of columns to be read (default=None, read all columns).
columns_list (list[str], optional): List of columns to be read (default=None, read all columns).
num_parallel_workers (int, optional): Number of parallel working threads (default=None).
num_parallel_workers (int, optional): Number of parallel working threads (default=None).
...
@@ -2193,7 +2193,10 @@ class TFRecordDataset(SourceDataset):
...
@@ -2193,7 +2193,10 @@ class TFRecordDataset(SourceDataset):
schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
If the schema is not provided, the meta data from the TFData file is considered the schema.
If the schema is not provided, the meta data from the TFData file is considered the schema.
columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
num_samples (int, optional): number of samples(rows) to read (default=None).
If num_samples is None and numRows(parsed from schema) is not exist, read the full dataset;
If num_samples is None and numRows(parsed from schema) is greater than 0, read numRows rows;
If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows.
num_parallel_workers (int, optional): number of workers to read the data
num_parallel_workers (int, optional): number of workers to read the data
(default=None, number set in the config).
(default=None, number set in the config).
shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
...
@@ -2711,10 +2714,10 @@ class Schema:
...
@@ -2711,10 +2714,10 @@ class Schema:
"""
"""
def
__init__
(
self
,
schema_file
=
None
):
def
__init__
(
self
,
schema_file
=
None
):
self
.
num_rows
=
None
if
schema_file
is
None
:
if
schema_file
is
None
:
self
.
columns
=
[]
self
.
columns
=
[]
self
.
dataset_type
=
''
self
.
dataset_type
=
''
self
.
num_rows
=
0
else
:
else
:
if
not
os
.
path
.
isfile
(
schema_file
)
or
not
os
.
access
(
schema_file
,
os
.
R_OK
):
if
not
os
.
path
.
isfile
(
schema_file
)
or
not
os
.
access
(
schema_file
,
os
.
R_OK
):
raise
ValueError
(
"The file %s does not exist or permission denied!"
%
schema_file
)
raise
ValueError
(
"The file %s does not exist or permission denied!"
%
schema_file
)
...
@@ -2859,6 +2862,9 @@ class Schema:
...
@@ -2859,6 +2862,9 @@ class Schema:
raise
RuntimeError
(
"DatasetType field is missing."
)
raise
RuntimeError
(
"DatasetType field is missing."
)
if
self
.
columns
is
None
:
if
self
.
columns
is
None
:
raise
RuntimeError
(
"Columns are missing."
)
raise
RuntimeError
(
"Columns are missing."
)
if
self
.
num_rows
is
not
None
:
if
not
isinstance
(
self
.
num_rows
,
int
)
or
self
.
num_rows
<=
0
:
raise
ValueError
(
"numRows must be greater than 0"
)
def
__str__
(
self
):
def
__str__
(
self
):
return
self
.
to_json
()
return
self
.
to_json
()
...
...
tests/ut/data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json
0 → 100644
浏览文件 @
db80f4ff
{
"datasetType"
:
"TF"
,
"columns"
:
{
"col_sint16"
:
{
"type"
:
"int16"
,
"rank"
:
1
,
"shape"
:
[
1
]
},
"col_sint32"
:
{
"type"
:
"int32"
,
"rank"
:
1
,
"shape"
:
[
1
]
},
"col_sint64"
:
{
"type"
:
"int64"
,
"rank"
:
1
,
"shape"
:
[
1
]
},
"col_float"
:
{
"type"
:
"float32"
,
"rank"
:
1
,
"shape"
:
[
1
]
},
"col_1d"
:
{
"type"
:
"int64"
,
"rank"
:
1
,
"shape"
:
[
2
]
},
"col_2d"
:
{
"type"
:
"int64"
,
"rank"
:
2
,
"shape"
:
[
2
,
2
]
},
"col_3d"
:
{
"type"
:
"int64"
,
"rank"
:
3
,
"shape"
:
[
2
,
2
,
2
]
},
"col_binary"
:
{
"type"
:
"uint8"
,
"rank"
:
1
,
"shape"
:
[
1
]
}
}
}
tests/ut/data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json
0 → 100644
浏览文件 @
db80f4ff
{
"datasetType"
:
"TF"
,
"columns"
:
{
"image"
:
{
"type"
:
"uint8"
,
"rank"
:
1
,
"t_impl"
:
"cvmat"
},
"label"
:
{
"type"
:
"uint64"
,
"rank"
:
1
,
"t_impl"
:
"flex"
}
}
}
tests/ut/python/dataset/test_storage.py
浏览文件 @
db80f4ff
...
@@ -37,3 +37,15 @@ def test_case_storage():
...
@@ -37,3 +37,15 @@ def test_case_storage():
filename
=
"storage_result.npz"
filename
=
"storage_result.npz"
save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
save_and_check
(
data1
,
parameters
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
def
test_case_no_rows
():
DATA_DIR
=
[
"../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"
]
SCHEMA_DIR
=
"../data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json"
dataset
=
ds
.
StorageDataset
(
DATA_DIR
,
SCHEMA_DIR
,
columns_list
=
[
"image"
])
assert
dataset
.
get_dataset_size
()
==
3
count
=
0
for
data
in
dataset
.
create_tuple_iterator
():
count
+=
1
assert
count
==
3
tests/ut/python/dataset/test_tfreader_op.py
浏览文件 @
db80f4ff
...
@@ -37,6 +37,36 @@ def test_case_tf_shape():
...
@@ -37,6 +37,36 @@ def test_case_tf_shape():
assert
(
len
(
output_shape
[
-
1
])
==
1
)
assert
(
len
(
output_shape
[
-
1
])
==
1
)
def
test_case_tf_read_all_dataset
():
schema_file
=
"../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
ds1
=
ds
.
TFRecordDataset
(
FILES
,
schema_file
)
assert
ds1
.
get_dataset_size
()
==
12
count
=
0
for
data
in
ds1
.
create_tuple_iterator
():
count
+=
1
assert
count
==
12
def
test_case_num_samples
():
schema_file
=
"../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
ds1
=
ds
.
TFRecordDataset
(
FILES
,
schema_file
,
num_samples
=
8
)
assert
ds1
.
get_dataset_size
()
==
8
count
=
0
for
data
in
ds1
.
create_dict_iterator
():
count
+=
1
assert
count
==
8
def
test_case_num_samples2
():
schema_file
=
"../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
ds1
=
ds
.
TFRecordDataset
(
FILES
,
schema_file
)
assert
ds1
.
get_dataset_size
()
==
7
count
=
0
for
data
in
ds1
.
create_dict_iterator
():
count
+=
1
assert
count
==
7
def
test_case_tf_shape_2
():
def
test_case_tf_shape_2
():
ds1
=
ds
.
TFRecordDataset
(
FILES
,
SCHEMA_FILE
)
ds1
=
ds
.
TFRecordDataset
(
FILES
,
SCHEMA_FILE
)
ds1
=
ds1
.
batch
(
2
)
ds1
=
ds1
.
batch
(
2
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录