Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
2ecd5bdf
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2ecd5bdf
编写于
6月 18, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 18, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2239 [MD] convert csv to mindrecord
Merge pull request !2239 from liyong126/csv_to_mindrecord
上级
d9b8da14
7369950a
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
322 addition
and
5 deletion
+322
-5
mindspore/mindrecord/__init__.py
mindspore/mindrecord/__init__.py
+2
-1
mindspore/mindrecord/tools/csv_to_mr.py
mindspore/mindrecord/tools/csv_to_mr.py
+168
-0
mindspore/mindrecord/tools/tfrecord_to_mr.py
mindspore/mindrecord/tools/tfrecord_to_mr.py
+2
-4
tests/ut/data/mindrecord/testCsv/data.csv
tests/ut/data/mindrecord/testCsv/data.csv
+7
-0
tests/ut/python/mindrecord/test_csv_to_mindrecord.py
tests/ut/python/mindrecord/test_csv_to_mindrecord.py
+143
-0
未找到文件。
mindspore/mindrecord/__init__.py
浏览文件 @
2ecd5bdf
...
...
@@ -29,10 +29,11 @@ from .common.exceptions import *
from
.shardutils
import
SUCCESS
,
FAILED
from
.tools.cifar10_to_mr
import
Cifar10ToMR
from
.tools.cifar100_to_mr
import
Cifar100ToMR
from
.tools.csv_to_mr
import
CsvToMR
from
.tools.imagenet_to_mr
import
ImageNetToMR
from
.tools.mnist_to_mr
import
MnistToMR
from
.tools.tfrecord_to_mr
import
TFRecordToMR
__all__
=
[
'FileWriter'
,
'FileReader'
,
'MindPage'
,
'Cifar10ToMR'
,
'Cifar100ToMR'
,
'ImageNetToMR'
,
'MnistToMR'
,
'TFRecordToMR'
,
'Cifar10ToMR'
,
'Cifar100ToMR'
,
'
CsvToMR'
,
'
ImageNetToMR'
,
'MnistToMR'
,
'TFRecordToMR'
,
'SUCCESS'
,
'FAILED'
]
mindspore/mindrecord/tools/csv_to_mr.py
0 → 100644
浏览文件 @
2ecd5bdf
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Csv format convert tool for MindRecord.
"""
from
importlib
import
import_module
import
os
from
mindspore
import
log
as
logger
from
..filewriter
import
FileWriter
from
..shardutils
import
check_filename
try
:
pd
=
import_module
(
"pandas"
)
except
ModuleNotFoundError
:
pd
=
None
__all__
=
[
'CsvToMR'
]
class
CsvToMR
:
"""
Class is for transformation from csv to MindRecord.
Args:
source (str): the file path of csv.
destination (str): the MindRecord file path to transform into.
columns_list(list[str], optional): List of columns to be read(default=None).
partition_number (int, optional): partition size (default=1).
Raises:
ValueError: If source, destination, partition_number is invalid.
RuntimeError: If columns_list is invalid.
"""
def
__init__
(
self
,
source
,
destination
,
columns_list
=
None
,
partition_number
=
1
):
if
not
pd
:
raise
Exception
(
"Module pandas is not found, please use pip install it."
)
if
isinstance
(
source
,
str
):
check_filename
(
source
)
self
.
source
=
source
else
:
raise
ValueError
(
"The parameter source must be str."
)
self
.
_check_columns
(
columns_list
,
"columns_list"
)
self
.
columns_list
=
columns_list
if
isinstance
(
destination
,
str
):
check_filename
(
destination
)
self
.
destination
=
destination
else
:
raise
ValueError
(
"The parameter destination must be str."
)
if
partition_number
is
not
None
:
if
not
isinstance
(
partition_number
,
int
):
raise
ValueError
(
"The parameter partition_number must be int"
)
self
.
partition_number
=
partition_number
else
:
raise
ValueError
(
"The parameter partition_number must be int"
)
self
.
writer
=
FileWriter
(
self
.
destination
,
self
.
partition_number
)
def
_check_columns
(
self
,
columns
,
columns_name
):
if
columns
:
if
isinstance
(
columns
,
list
):
for
col
in
columns
:
if
not
isinstance
(
col
,
str
):
raise
ValueError
(
"The parameter {} must be list of str."
.
format
(
columns_name
))
else
:
raise
ValueError
(
"The parameter {} must be list of str."
.
format
(
columns_name
))
def
_get_schema
(
self
,
df
):
"""
Construct schema from df columns
"""
if
self
.
columns_list
:
for
col
in
self
.
columns_list
:
if
col
not
in
df
.
columns
:
raise
RuntimeError
(
"The parameter columns_list is illegal, column {} does not exist."
.
format
(
col
))
else
:
self
.
columns_list
=
df
.
columns
schema
=
{}
for
col
in
self
.
columns_list
:
if
str
(
df
[
col
].
dtype
)
==
'int64'
:
schema
[
col
]
=
{
"type"
:
"int64"
}
elif
str
(
df
[
col
].
dtype
)
==
'float64'
:
schema
[
col
]
=
{
"type"
:
"float64"
}
elif
str
(
df
[
col
].
dtype
)
==
'bool'
:
schema
[
col
]
=
{
"type"
:
"int32"
}
else
:
schema
[
col
]
=
{
"type"
:
"string"
}
if
not
schema
:
raise
RuntimeError
(
"Failed to generate schema from csv file."
)
return
schema
def
_get_row_of_csv
(
self
,
df
):
"""Get row data from csv file."""
for
_
,
r
in
df
.
iterrows
():
row
=
{}
for
col
in
self
.
columns_list
:
if
str
(
df
[
col
].
dtype
)
==
'bool'
:
row
[
col
]
=
int
(
r
[
col
])
else
:
row
[
col
]
=
r
[
col
]
yield
row
def
transform
(
self
):
"""
Executes transformation from csv to MindRecord.
Returns:
SUCCESS/FAILED, whether successfully written into MindRecord.
"""
if
not
os
.
path
.
exists
(
self
.
source
):
raise
IOError
(
"Csv file {} do not exist."
.
format
(
self
.
source
))
pd
.
set_option
(
'display.max_columns'
,
None
)
df
=
pd
.
read_csv
(
self
.
source
)
csv_schema
=
self
.
_get_schema
(
df
)
logger
.
info
(
"transformed MindRecord schema is: {}"
.
format
(
csv_schema
))
# set the header size
self
.
writer
.
set_header_size
(
1
<<
24
)
# set the page size
self
.
writer
.
set_page_size
(
1
<<
26
)
# create the schema
self
.
writer
.
add_schema
(
csv_schema
,
"csv_schema"
)
# add the index
self
.
writer
.
add_index
(
list
(
self
.
columns_list
))
csv_iter
=
self
.
_get_row_of_csv
(
df
)
batch_size
=
256
transform_count
=
0
while
True
:
data_list
=
[]
try
:
for
_
in
range
(
batch_size
):
data_list
.
append
(
csv_iter
.
__next__
())
transform_count
+=
1
self
.
writer
.
write_raw_data
(
data_list
)
logger
.
info
(
"transformed {} record..."
.
format
(
transform_count
))
except
StopIteration
:
if
data_list
:
self
.
writer
.
write_raw_data
(
data_list
)
logger
.
info
(
"transformed {} record..."
.
format
(
transform_count
))
break
ret
=
self
.
writer
.
commit
()
return
ret
mindspore/mindrecord/tools/tfrecord_to_mr.py
浏览文件 @
2ecd5bdf
...
...
@@ -115,10 +115,8 @@ class TFRecordToMR:
"sequence": {"zzzz": tf.io.FixedLenSequenceFeature([], tf.float32)}}
bytes_fields (list): the bytes fields which are in feature_dict.
Rasies:
ValueError: the following condition will cause ValueError, 1) parameter TFRecord is not string, 2) parameter
MindRecord is not string, 3) feature_dict is not FixedLenFeature, 4) parameter bytes_field is not list(str)
or not in feature_dict.
Raises:
ValueError: If parameter is invalid.
Exception: when tensorflow module not found or version is not correct.
"""
def
__init__
(
self
,
source
,
destination
,
feature_dict
,
bytes_fields
=
None
):
...
...
tests/ut/data/mindrecord/testCsv/data.csv
0 → 100644
浏览文件 @
2ecd5bdf
Age,EmployNumber,Name,Sales,Over18
21, 10023,john, 123.45,True
41, 10223,tom, 12111,True
51, 10231,bob, 8779.0,True
86, 10053,alice, 7777,True
26, 1053,carol, 12345.8,False
tests/ut/python/mindrecord/test_csv_to_mindrecord.py
0 → 100644
浏览文件 @
2ecd5bdf
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test csv to mindrecord tool"""
import
os
from
importlib
import
import_module
import
pytest
from
mindspore
import
log
as
logger
from
mindspore.mindrecord
import
FileReader
from
mindspore.mindrecord
import
CsvToMR
try
:
pd
=
import_module
(
'pandas'
)
except
ModuleNotFoundError
:
pd
=
None
CSV_FILE
=
"../data/mindrecord/testCsv/data.csv"
MINDRECORD_FILE
=
"../data/mindrecord/testCsv/csv.mindrecord"
PARTITION_NUMBER
=
4
@
pytest
.
fixture
(
name
=
"remove_mindrecord_file"
)
def
fixture_remove
():
"""add/remove file"""
def
remove_one_file
(
x
):
if
os
.
path
.
exists
(
x
):
os
.
remove
(
x
)
def
remove_file
():
x
=
MINDRECORD_FILE
remove_one_file
(
x
)
x
=
MINDRECORD_FILE
+
".db"
remove_one_file
(
x
)
for
i
in
range
(
PARTITION_NUMBER
):
x
=
MINDRECORD_FILE
+
str
(
i
)
remove_one_file
(
x
)
x
=
MINDRECORD_FILE
+
str
(
i
)
+
".db"
remove_one_file
(
x
)
remove_file
()
yield
"yield_fixture_data"
remove_file
()
def
read
(
filename
,
columns
,
row_num
):
"""test file reade"""
if
not
pd
:
raise
Exception
(
"Module pandas is not found, please use pip install it."
)
df
=
pd
.
read_csv
(
CSV_FILE
)
count
=
0
reader
=
FileReader
(
filename
)
for
_
,
x
in
enumerate
(
reader
.
get_next
()):
for
col
in
columns
:
assert
x
[
col
]
==
df
[
col
].
iloc
[
count
]
assert
len
(
x
)
==
len
(
columns
)
count
=
count
+
1
if
count
==
1
:
logger
.
info
(
"data: {}"
.
format
(
x
))
assert
count
==
row_num
reader
.
close
()
def
test_csv_to_mindrecord
(
remove_mindrecord_file
):
"""test transform csv to mindrecord."""
csv_trans
=
CsvToMR
(
CSV_FILE
,
MINDRECORD_FILE
,
partition_number
=
PARTITION_NUMBER
)
csv_trans
.
transform
()
for
i
in
range
(
PARTITION_NUMBER
):
assert
os
.
path
.
exists
(
MINDRECORD_FILE
+
str
(
i
))
assert
os
.
path
.
exists
(
MINDRECORD_FILE
+
str
(
i
)
+
".db"
)
read
(
MINDRECORD_FILE
+
"0"
,
[
"Age"
,
"EmployNumber"
,
"Name"
,
"Sales"
,
"Over18"
],
5
)
def
test_csv_to_mindrecord_with_columns
(
remove_mindrecord_file
):
"""test transform csv to mindrecord."""
csv_trans
=
CsvToMR
(
CSV_FILE
,
MINDRECORD_FILE
,
columns_list
=
[
'Age'
,
'Sales'
],
partition_number
=
PARTITION_NUMBER
)
csv_trans
.
transform
()
for
i
in
range
(
PARTITION_NUMBER
):
assert
os
.
path
.
exists
(
MINDRECORD_FILE
+
str
(
i
))
assert
os
.
path
.
exists
(
MINDRECORD_FILE
+
str
(
i
)
+
".db"
)
read
(
MINDRECORD_FILE
+
"0"
,
[
"Age"
,
"Sales"
],
5
)
def
test_csv_to_mindrecord_with_no_exist_columns
(
remove_mindrecord_file
):
"""test transform csv to mindrecord."""
with
pytest
.
raises
(
Exception
,
match
=
"The parameter columns_list is illegal, column ssales does not exist."
):
csv_trans
=
CsvToMR
(
CSV_FILE
,
MINDRECORD_FILE
,
columns_list
=
[
'Age'
,
'ssales'
],
partition_number
=
PARTITION_NUMBER
)
csv_trans
.
transform
()
def
test_csv_partition_number_with_illegal_columns
(
remove_mindrecord_file
):
"""
test transform csv to mindrecord
"""
with
pytest
.
raises
(
Exception
,
match
=
"The parameter columns_list must be list of str."
):
csv_trans
=
CsvToMR
(
CSV_FILE
,
MINDRECORD_FILE
,
[
"Sales"
,
2
])
csv_trans
.
transform
()
def
test_csv_to_mindrecord_default_partition_number
(
remove_mindrecord_file
):
"""
test transform csv to mindrecord
when partition number is default.
"""
csv_trans
=
CsvToMR
(
CSV_FILE
,
MINDRECORD_FILE
)
csv_trans
.
transform
()
assert
os
.
path
.
exists
(
MINDRECORD_FILE
)
assert
os
.
path
.
exists
(
MINDRECORD_FILE
+
".db"
)
read
(
MINDRECORD_FILE
,
[
"Age"
,
"EmployNumber"
,
"Name"
,
"Sales"
,
"Over18"
],
5
)
def
test_csv_partition_number_0
(
remove_mindrecord_file
):
"""
test transform csv to mindrecord
when partition number is 0.
"""
with
pytest
.
raises
(
Exception
,
match
=
"Invalid parameter value"
):
csv_trans
=
CsvToMR
(
CSV_FILE
,
MINDRECORD_FILE
,
None
,
0
)
csv_trans
.
transform
()
def
test_csv_to_mindrecord_partition_number_none
(
remove_mindrecord_file
):
"""
test transform csv to mindrecord
when partition number is none.
"""
with
pytest
.
raises
(
Exception
,
match
=
"The parameter partition_number must be int"
):
csv_trans
=
CsvToMR
(
CSV_FILE
,
MINDRECORD_FILE
,
None
,
None
)
csv_trans
.
transform
()
def
test_csv_to_mindrecord_illegal_filename
(
remove_mindrecord_file
):
"""
test transform csv to mindrecord
when file name contains illegal character.
"""
filename
=
"not_*ok"
with
pytest
.
raises
(
Exception
,
match
=
"File name should not contains"
):
csv_trans
=
CsvToMR
(
CSV_FILE
,
filename
)
csv_trans
.
transform
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录