Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
a9443635
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a9443635
编写于
4月 09, 2020
作者:
J
jonyguo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix: mindpage enhance parameter check and search by filename failed
上级
aaa8d9ed
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
354 addition
and
65 deletion
+354
-65
mindspore/ccsrc/mindrecord/include/common/shard_utils.h
mindspore/ccsrc/mindrecord/include/common/shard_utils.h
+7
-0
mindspore/ccsrc/mindrecord/include/shard_index_generator.h
mindspore/ccsrc/mindrecord/include/shard_index_generator.h
+5
-5
mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
+50
-28
mindspore/ccsrc/mindrecord/io/shard_reader.cc
mindspore/ccsrc/mindrecord/io/shard_reader.cc
+51
-15
mindspore/ccsrc/mindrecord/io/shard_segment.cc
mindspore/ccsrc/mindrecord/io/shard_segment.cc
+13
-4
mindspore/mindrecord/mindpage.py
mindspore/mindrecord/mindpage.py
+13
-11
tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc
tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc
+3
-0
tests/ut/python/mindrecord/test_mindrecord_base.py
tests/ut/python/mindrecord/test_mindrecord_base.py
+145
-0
tests/ut/python/mindrecord/test_mindrecord_exception.py
tests/ut/python/mindrecord/test_mindrecord_exception.py
+67
-2
未找到文件。
mindspore/ccsrc/mindrecord/include/common/shard_utils.h
浏览文件 @
a9443635
...
...
@@ -33,6 +33,7 @@
#include <map>
#include <random>
#include <set>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
...
...
@@ -117,6 +118,12 @@ const char kPoint = '.';
// Every field type a schema may declare; consulted during schema validation.
const std::set<std::string> kFieldTypeSet{"bytes", "string", "int32", "int64", "float32", "float64"};

// Field types that may be used as a search key (all schema types except "bytes").
const std::set<std::string> kScalarFieldTypeSet{"string", "int32", "int64", "float32", "float64"};

// Numeric field types.
const std::set<std::string> kNumberFieldTypeSet{"int32", "int64", "float32", "float64"};
/// \brief split a string using a character
/// \param[in] field target string
/// \param[in] separator a character for splitting
...
...
mindspore/ccsrc/mindrecord/include/shard_index_generator.h
浏览文件 @
a9443635
...
...
@@ -42,11 +42,11 @@ class ShardIndexGenerator {
~
ShardIndexGenerator
()
{}
/// \brief fetch value in json by field
path
/// \param[in] field
_path
/// \param[in]
schema
/// \return
the vector of value
st
atic
std
::
vector
<
std
::
string
>
GetField
(
const
std
::
string
&
field_path
,
json
schema
);
/// \brief fetch value in json by field
name
/// \param[in] field
/// \param[in]
input
/// \return
pair<MSRStatus, value>
st
d
::
pair
<
MSRStatus
,
std
::
string
>
GetValueByField
(
const
string
&
field
,
json
input
);
/// \brief fetch field type in schema n by field path
/// \param[in] field_path
...
...
mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
浏览文件 @
a9443635
...
...
@@ -38,7 +38,7 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe
MSRStatus
ShardIndexGenerator
::
Build
()
{
ShardHeader
header
=
ShardHeader
();
if
(
header
.
Build
(
file_path_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
"Build shard schema failed"
;
MS_LOG
(
ERROR
)
<<
"Build shard schema failed
.
"
;
return
FAILED
;
}
shard_header_
=
header
;
...
...
@@ -46,35 +46,49 @@ MSRStatus ShardIndexGenerator::Build() {
return
SUCCESS
;
}
std
::
vector
<
std
::
string
>
ShardIndexGenerator
::
GetField
(
const
string
&
field_path
,
json
schema
)
{
std
::
vector
<
std
::
string
>
field_name
=
StringSplit
(
field_path
,
kPoint
);
std
::
vector
<
std
::
string
>
res
;
if
(
schema
.
empty
())
{
res
.
emplace_back
(
"null"
);
return
res
;
std
::
pair
<
MSRStatus
,
std
::
string
>
ShardIndexGenerator
::
GetValueByField
(
const
string
&
field
,
json
input
)
{
if
(
field
.
empty
())
{
MS_LOG
(
ERROR
)
<<
"The input field is None."
;
return
{
FAILED
,
""
};
}
for
(
uint64_t
i
=
0
;
i
<
field_name
.
size
();
i
++
)
{
// Check if field is part of an array of objects
auto
&
child
=
schema
.
at
(
field_name
[
i
]);
if
(
child
.
is_array
()
&&
!
child
.
empty
()
&&
child
[
0
].
is_object
())
{
schema
=
schema
[
field_name
[
i
]];
std
::
string
new_field_path
;
for
(
uint64_t
j
=
i
+
1
;
j
<
field_name
.
size
();
j
++
)
{
if
(
j
>
i
+
1
)
new_field_path
+=
'.'
;
new_field_path
+=
field_name
[
j
];
}
// Return multiple field data since multiple objects in array
for
(
auto
&
single_schema
:
schema
)
{
auto
child_res
=
GetField
(
new_field_path
,
single_schema
);
res
.
insert
(
res
.
end
(),
child_res
.
begin
(),
child_res
.
end
());
}
return
res
;
if
(
input
.
empty
())
{
MS_LOG
(
ERROR
)
<<
"The input json is None."
;
return
{
FAILED
,
""
};
}
// parameter input does not contain the field
if
(
input
.
find
(
field
)
==
input
.
end
())
{
MS_LOG
(
ERROR
)
<<
"The field "
<<
field
<<
" is not found in parameter "
<<
input
;
return
{
FAILED
,
""
};
}
// schema does not contain the field
auto
schema
=
shard_header_
.
get_schemas
()[
0
]
->
GetSchema
()[
"schema"
];
if
(
schema
.
find
(
field
)
==
schema
.
end
())
{
MS_LOG
(
ERROR
)
<<
"The field "
<<
field
<<
" is not found in schema "
<<
schema
;
return
{
FAILED
,
""
};
}
// field should be scalar type
if
(
kScalarFieldTypeSet
.
find
(
schema
[
field
][
"type"
])
==
kScalarFieldTypeSet
.
end
())
{
MS_LOG
(
ERROR
)
<<
"The field "
<<
field
<<
" type is "
<<
schema
[
field
][
"type"
]
<<
", it is not retrievable"
;
return
{
FAILED
,
""
};
}
if
(
kNumberFieldTypeSet
.
find
(
schema
[
field
][
"type"
])
!=
kNumberFieldTypeSet
.
end
())
{
auto
schema_field_options
=
schema
[
field
];
if
(
schema_field_options
.
find
(
"shape"
)
==
schema_field_options
.
end
())
{
return
{
SUCCESS
,
input
[
field
].
dump
()};
}
else
{
// field with shape option
MS_LOG
(
ERROR
)
<<
"The field "
<<
field
<<
" shape is "
<<
schema
[
field
][
"shape"
]
<<
" which is not retrievable"
;
return
{
FAILED
,
""
};
}
schema
=
schema
.
at
(
field_name
[
i
]);
}
//
Return vector of one field data (not array of objects)
return
std
::
vector
<
std
::
string
>
{
schema
.
dump
()};
//
the field type is string in here
return
{
SUCCESS
,
input
[
field
].
get
<
std
::
string
>
()};
}
std
::
string
ShardIndexGenerator
::
TakeFieldType
(
const
string
&
field_path
,
json
schema
)
{
...
...
@@ -304,6 +318,7 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
const
auto
&
place_holder
=
std
::
get
<
0
>
(
field
);
const
auto
&
field_type
=
std
::
get
<
1
>
(
field
);
const
auto
&
field_value
=
std
::
get
<
2
>
(
field
);
int
index
=
sqlite3_bind_parameter_index
(
stmt
,
common
::
SafeCStr
(
place_holder
));
if
(
field_type
==
"INTEGER"
)
{
if
(
sqlite3_bind_int
(
stmt
,
index
,
std
::
stoi
(
field_value
))
!=
SQLITE_OK
)
{
...
...
@@ -463,17 +478,24 @@ INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &s
if
(
field
.
first
>=
schema_detail
.
size
())
{
return
{
FAILED
,
{}};
}
auto
field_value
=
GetField
(
field
.
second
,
schema_detail
[
field
.
first
]);
auto
field_value
=
GetValueByField
(
field
.
second
,
schema_detail
[
field
.
first
]);
if
(
field_value
.
first
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
"Get value from json by field name failed"
;
return
{
FAILED
,
{}};
}
auto
result
=
shard_header_
.
GetSchemaByID
(
field
.
first
);
if
(
result
.
second
!=
SUCCESS
)
{
return
{
FAILED
,
{}};
}
std
::
string
field_type
=
ConvertJsonToSQL
(
TakeFieldType
(
field
.
second
,
result
.
first
->
GetSchema
()[
"schema"
]));
auto
ret
=
GenerateFieldName
(
field
);
if
(
ret
.
first
!=
SUCCESS
)
{
return
{
FAILED
,
{}};
}
fields
.
emplace_back
(
ret
.
second
,
field_type
,
field_value
[
0
]);
fields
.
emplace_back
(
ret
.
second
,
field_type
,
field_value
.
second
);
}
return
{
SUCCESS
,
std
::
move
(
fields
)};
}
...
...
mindspore/ccsrc/mindrecord/io/shard_reader.cc
浏览文件 @
a9443635
...
...
@@ -25,6 +25,15 @@ using mindspore::MsLogLevel::INFO;
namespace
mindspore
{
namespace
mindrecord
{
/// \brief convert a string to the requested number type (int32_t/int64_t/float/double)
/// \param[in] str text to parse; leading whitespace is skipped by the stream extraction
/// \return the parsed number, or 0 when nothing could be extracted
template <class Type>
Type StringToNum(const std::string &str) {
  std::istringstream iss(str);
  // Value-initialise the result: when the input is empty or only whitespace the
  // stream sentry fails and operator>> never writes to the target, so an
  // uninitialised local would be returned with an indeterminate value.
  Type num{};
  iss >> num;
  return num;
}
ShardReader
::
ShardReader
()
{
task_id_
=
0
;
deliver_id_
=
0
;
...
...
@@ -259,16 +268,25 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
}
column_values
[
shard_id
].
emplace_back
(
tmp
);
}
else
{
string
json_str
=
"{"
;
json
construct_json
;
for
(
unsigned
int
j
=
0
;
j
<
columns
.
size
();
++
j
)
{
// construct the string json "f1": value
json_str
=
json_str
+
"
\"
"
+
columns
[
j
]
+
"
\"
:"
+
labels
[
i
][
j
+
3
];
if
(
j
<
columns
.
size
()
-
1
)
{
json_str
+=
","
;
// construct json "f1": value
auto
schema
=
shard_header_
->
get_schemas
()[
0
]
->
GetSchema
()[
"schema"
];
// convert the string to base type by schema
if
(
schema
[
columns
[
j
]][
"type"
]
==
"int32"
)
{
construct_json
[
columns
[
j
]]
=
StringToNum
<
int32_t
>
(
labels
[
i
][
j
+
3
]);
}
else
if
(
schema
[
columns
[
j
]][
"type"
]
==
"int64"
)
{
construct_json
[
columns
[
j
]]
=
StringToNum
<
int64_t
>
(
labels
[
i
][
j
+
3
]);
}
else
if
(
schema
[
columns
[
j
]][
"type"
]
==
"float32"
)
{
construct_json
[
columns
[
j
]]
=
StringToNum
<
float
>
(
labels
[
i
][
j
+
3
]);
}
else
if
(
schema
[
columns
[
j
]][
"type"
]
==
"float64"
)
{
construct_json
[
columns
[
j
]]
=
StringToNum
<
double
>
(
labels
[
i
][
j
+
3
]);
}
else
{
construct_json
[
columns
[
j
]]
=
std
::
string
(
labels
[
i
][
j
+
3
]);
}
}
json_str
+=
"}"
;
column_values
[
shard_id
].
emplace_back
(
json
::
parse
(
json_str
));
column_values
[
shard_id
].
emplace_back
(
construct_json
);
}
}
...
...
@@ -402,7 +420,16 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int
// whether use index search
if
(
!
criteria
.
first
.
empty
())
{
sql
+=
" AND "
+
criteria
.
first
+
"_"
+
std
::
to_string
(
column_schema_id_
[
criteria
.
first
])
+
" = "
+
criteria
.
second
;
auto
schema
=
shard_header_
->
get_schemas
()[
0
]
->
GetSchema
();
// not number field should add '' in sql
if
(
kNumberFieldTypeSet
.
find
(
schema
[
"schema"
][
criteria
.
first
][
"type"
])
!=
kNumberFieldTypeSet
.
end
())
{
sql
+=
" AND "
+
criteria
.
first
+
"_"
+
std
::
to_string
(
column_schema_id_
[
criteria
.
first
])
+
" = "
+
criteria
.
second
;
}
else
{
sql
+=
" AND "
+
criteria
.
first
+
"_"
+
std
::
to_string
(
column_schema_id_
[
criteria
.
first
])
+
" = '"
+
criteria
.
second
+
"'"
;
}
}
sql
+=
";"
;
std
::
vector
<
std
::
vector
<
std
::
string
>>
image_offsets
;
...
...
@@ -603,16 +630,25 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int
std
::
vector
<
json
>
ret
;
for
(
unsigned
int
i
=
0
;
i
<
labels
.
size
();
++
i
)
ret
.
emplace_back
(
json
{});
for
(
unsigned
int
i
=
0
;
i
<
labels
.
size
();
++
i
)
{
string
json_str
=
"{"
;
json
construct_json
;
for
(
unsigned
int
j
=
0
;
j
<
columns
.
size
();
++
j
)
{
// construct string json "f1": value
json_str
=
json_str
+
"
\"
"
+
columns
[
j
]
+
"
\"
:"
+
labels
[
i
][
j
];
if
(
j
<
columns
.
size
()
-
1
)
{
json_str
+=
","
;
// construct json "f1": value
auto
schema
=
shard_header_
->
get_schemas
()[
0
]
->
GetSchema
()[
"schema"
];
// convert the string to base type by schema
if
(
schema
[
columns
[
j
]][
"type"
]
==
"int32"
)
{
construct_json
[
columns
[
j
]]
=
StringToNum
<
int32_t
>
(
labels
[
i
][
j
]);
}
else
if
(
schema
[
columns
[
j
]][
"type"
]
==
"int64"
)
{
construct_json
[
columns
[
j
]]
=
StringToNum
<
int64_t
>
(
labels
[
i
][
j
]);
}
else
if
(
schema
[
columns
[
j
]][
"type"
]
==
"float32"
)
{
construct_json
[
columns
[
j
]]
=
StringToNum
<
float
>
(
labels
[
i
][
j
]);
}
else
if
(
schema
[
columns
[
j
]][
"type"
]
==
"float64"
)
{
construct_json
[
columns
[
j
]]
=
StringToNum
<
double
>
(
labels
[
i
][
j
]);
}
else
{
construct_json
[
columns
[
j
]]
=
std
::
string
(
labels
[
i
][
j
]);
}
}
json_str
+=
"}"
;
ret
[
i
]
=
json
::
parse
(
json_str
);
ret
[
i
]
=
construct_json
;
}
return
{
SUCCESS
,
ret
};
}
...
...
mindspore/ccsrc/mindrecord/io/shard_segment.cc
浏览文件 @
a9443635
...
...
@@ -311,14 +311,23 @@ std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardS
MS_LOG
(
ERROR
)
<<
"Get category info"
;
return
{
FAILED
,
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
{}};
}
// category_name to category_id
int64_t
category_id
=
-
1
;
for
(
const
auto
&
categories
:
ret
.
second
)
{
if
(
std
::
get
<
1
>
(
categories
)
==
category_name
)
{
auto
result
=
ReadAllAtPageById
(
std
::
get
<
0
>
(
categories
),
page_no
,
n_rows_of_page
);
return
{
SUCCESS
,
result
.
second
};
std
::
string
categories_name
=
std
::
get
<
1
>
(
categories
);
if
(
categories_name
==
category_name
)
{
category_id
=
std
::
get
<
0
>
(
categories
);
break
;
}
}
return
{
SUCCESS
,
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
{}};
if
(
category_id
==
-
1
)
{
return
{
FAILED
,
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
{}};
}
return
ReadAllAtPageById
(
category_id
,
page_no
,
n_rows_of_page
);
}
std
::
pair
<
MSRStatus
,
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
pybind11
::
object
>>>
ShardSegment
::
ReadAtPageByIdPy
(
...
...
mindspore/mindrecord/mindpage.py
浏览文件 @
a9443635
...
...
@@ -133,15 +133,15 @@ class MindPage:
Raises:
ParamValueError: If any parameter is invalid.
MRMFetchDataError: If failed to
read by category id
.
MRMFetchDataError: If failed to
fetch data by category
.
MRMUnsupportedSchemaError: If schema is invalid.
"""
if
category_id
<
0
:
raise
ParamValueError
(
"Category id should be
greater than
0."
)
if
page
<
0
:
raise
ParamValueError
(
"Page should be
greater than
0."
)
if
n
um_row
<
0
:
raise
ParamValueError
(
"num_row should be greater than 0."
)
if
not
isinstance
(
category_id
,
int
)
or
category_id
<
0
:
raise
ParamValueError
(
"Category id should be
int and greater than or equal to
0."
)
if
not
isinstance
(
page
,
int
)
or
page
<
0
:
raise
ParamValueError
(
"Page should be
int and greater than or equal to
0."
)
if
n
ot
isinstance
(
num_row
,
int
)
or
num_row
<=
0
:
raise
ParamValueError
(
"num_row should be
int and
greater than 0."
)
return
self
.
_segment
.
read_at_page_by_id
(
category_id
,
page
,
num_row
)
def
read_at_page_by_name
(
self
,
category_name
,
page
,
num_row
):
...
...
@@ -157,8 +157,10 @@ class MindPage:
Returns:
str, read at page.
"""
if
page
<
0
:
raise
ParamValueError
(
"Page should be greater than 0."
)
if
num_row
<
0
:
raise
ParamValueError
(
"num_row should be greater than 0."
)
if
not
isinstance
(
category_name
,
str
):
raise
ParamValueError
(
"Category name should be str."
)
if
not
isinstance
(
page
,
int
)
or
page
<
0
:
raise
ParamValueError
(
"Page should be int and greater than or equal to 0."
)
if
not
isinstance
(
num_row
,
int
)
or
num_row
<=
0
:
raise
ParamValueError
(
"num_row should be int and greater than 0."
)
return
self
.
_segment
.
read_at_page_by_name
(
category_name
,
page
,
num_row
)
tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc
浏览文件 @
a9443635
...
...
@@ -53,6 +53,7 @@ class TestShardIndexGenerator : public UT::Common {
TestShardIndexGenerator
()
{}
};
/*
TEST_F(TestShardIndexGenerator, GetField) {
MS_LOG(INFO) << FormatInfo("Test ShardIndex: get field");
...
...
@@ -82,6 +83,8 @@ TEST_F(TestShardIndexGenerator, GetField) {
}
}
}
*/
TEST_F
(
TestShardIndexGenerator
,
TakeFieldType
)
{
MS_LOG
(
INFO
)
<<
FormatInfo
(
"Test ShardSchema: take field Type"
);
...
...
tests/ut/python/mindrecord/test_mindrecord_base.py
浏览文件 @
a9443635
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""test mindrecord base"""
import
numpy
as
np
import
os
import
uuid
from
mindspore.mindrecord
import
FileWriter
,
FileReader
,
MindPage
,
SUCCESS
...
...
@@ -25,6 +26,105 @@ CV2_FILE_NAME = "./imagenet_loop.mindrecord"
CV3_FILE_NAME
=
"./imagenet_append.mindrecord"
NLP_FILE_NAME
=
"./aclImdb.mindrecord"
def test_write_read_process():
    """Write six records (scalar, array and bytes fields) and read them back unchanged."""
    record_path = "test.mindrecord"

    def make_row(file_name, label, score, mask, segments, payload):
        # Key order matches the schema declared below.
        return {"file_name": file_name,
                "label": label,
                "score": score,
                "mask": np.array(mask, dtype=np.int64),
                "segments": np.array(segments, dtype=np.float32),
                "data": bytes(payload, encoding='UTF-8')}

    data = [
        make_row("001.jpg", 43, 0.8, [3, 6, 9], [[5.0, 1.6], [65.2, 8.3]], "image bytes abc"),
        make_row("002.jpg", 91, 5.4, [1, 4, 7], [[5.1, 9.1], [2.0, 65.4]], "image bytes def"),
        make_row("003.jpg", 61, 6.4, [7, 6, 3], [[0.0, 5.6], [3.0, 16.3]], "image bytes ghi"),
        make_row("004.jpg", 29, 8.1, [2, 8, 0], [[5.9, 7.2], [4.0, 89.0]], "image bytes jkl"),
        make_row("005.jpg", 78, 7.7, [3, 1, 2], [[0.6, 8.1], [5.3, 49.3]], "image bytes mno"),
        make_row("006.jpg", 37, 9.4, [7, 6, 7], [[4.2, 6.3], [8.9, 81.8]], "image bytes pqr"),
    ]

    schema = {
        "file_name": {"type": "string"},
        "label": {"type": "int32"},
        "score": {"type": "float64"},
        "mask": {"type": "int64", "shape": [-1]},
        "segments": {"type": "float32", "shape": [2, 2]},
        "data": {"type": "bytes"},
    }

    writer = FileWriter(record_path)
    writer.add_schema(schema, "data is so cool")
    writer.write_raw_data(data)
    writer.commit()

    reader = FileReader(record_path)
    seen = 0
    for index, item in enumerate(reader.get_next()):
        assert len(item) == 6
        expected = data[seen]
        for field in item:
            actual = item[field]
            if isinstance(actual, np.ndarray):
                # Array fields are compared element-wise.
                assert (actual == expected[field]).all()
            else:
                assert actual == expected[field]
        seen += 1
        logger.info("#item{}: {}".format(index, item))
    assert seen == 6
    reader.close()

    # Remove the record file and its companion .db file.
    for template in ("{}", "{}.db"):
        os.remove(template.format(record_path))
def test_write_read_process_with_define_index_field():
    """Same write/read round trip as the plain test, with "label" declared as an index field."""
    record_path = "test.mindrecord"

    def make_row(file_name, label, score, mask, segments, payload):
        # Key order matches the schema declared below.
        return {"file_name": file_name,
                "label": label,
                "score": score,
                "mask": np.array(mask, dtype=np.int64),
                "segments": np.array(segments, dtype=np.float32),
                "data": bytes(payload, encoding='UTF-8')}

    data = [
        make_row("001.jpg", 43, 0.8, [3, 6, 9], [[5.0, 1.6], [65.2, 8.3]], "image bytes abc"),
        make_row("002.jpg", 91, 5.4, [1, 4, 7], [[5.1, 9.1], [2.0, 65.4]], "image bytes def"),
        make_row("003.jpg", 61, 6.4, [7, 6, 3], [[0.0, 5.6], [3.0, 16.3]], "image bytes ghi"),
        make_row("004.jpg", 29, 8.1, [2, 8, 0], [[5.9, 7.2], [4.0, 89.0]], "image bytes jkl"),
        make_row("005.jpg", 78, 7.7, [3, 1, 2], [[0.6, 8.1], [5.3, 49.3]], "image bytes mno"),
        make_row("006.jpg", 37, 9.4, [7, 6, 7], [[4.2, 6.3], [8.9, 81.8]], "image bytes pqr"),
    ]

    schema = {
        "file_name": {"type": "string"},
        "label": {"type": "int32"},
        "score": {"type": "float64"},
        "mask": {"type": "int64", "shape": [-1]},
        "segments": {"type": "float32", "shape": [2, 2]},
        "data": {"type": "bytes"},
    }

    writer = FileWriter(record_path)
    writer.add_schema(schema, "data is so cool")
    # The only difference from the plain round-trip test: an explicit index field.
    writer.add_index(["label"])
    writer.write_raw_data(data)
    writer.commit()

    reader = FileReader(record_path)
    seen = 0
    for index, item in enumerate(reader.get_next()):
        assert len(item) == 6
        expected = data[seen]
        for field in item:
            actual = item[field]
            if isinstance(actual, np.ndarray):
                # Array fields are compared element-wise.
                assert (actual == expected[field]).all()
            else:
                assert actual == expected[field]
        seen += 1
        logger.info("#item{}: {}".format(index, item))
    assert seen == 6
    reader.close()

    # Remove the record file and its companion .db file.
    for template in ("{}", "{}.db"):
        os.remove(template.format(record_path))
def
test_cv_file_writer_tutorial
():
"""tutorial for cv dataset writer."""
writer
=
FileWriter
(
CV_FILE_NAME
,
FILES_NUM
)
...
...
@@ -137,6 +237,51 @@ def test_cv_page_reader_tutorial():
assert
len
(
row1
[
0
])
==
3
assert
row1
[
0
][
'label'
]
==
822
def test_cv_page_reader_tutorial_by_file_name():
    """Page-reader tutorial using "file_name" as the category field."""
    page_reader = MindPage(CV_FILE_NAME + "0")

    candidate_fields = page_reader.get_category_fields()
    assert candidate_fields == ['file_name', 'label'], \
        'failed on getting candidate category fields.'

    status = page_reader.set_category_field("file_name")
    assert status == SUCCESS, 'failed on setting category field.'

    category_info = page_reader.read_category_info()
    logger.info("category info: {}".format(category_info))

    # Look up one row by category id.
    rows_by_id = page_reader.read_at_page_by_id(0, 0, 1)
    assert len(rows_by_id) == 1
    first_by_id = rows_by_id[0]
    assert len(first_by_id) == 3
    assert first_by_id['label'] == 490

    # Look up one row by category name (here, a file name).
    rows_by_name = page_reader.read_at_page_by_name("image_00007.jpg", 0, 1)
    assert len(rows_by_name) == 1
    first_by_name = rows_by_name[0]
    assert len(first_by_name) == 3
    assert first_by_name['label'] == 13
def
test_cv_page_reader_tutorial_new_api
():
"""tutorial for cv page reader."""
reader
=
MindPage
(
CV_FILE_NAME
+
"0"
)
fields
=
reader
.
candidate_fields
assert
fields
==
[
'file_name'
,
'label'
],
\
'failed on getting candidate category fields.'
reader
.
category_field
=
"file_name"
info
=
reader
.
read_category_info
()
logger
.
info
(
"category info: {}"
.
format
(
info
))
row
=
reader
.
read_at_page_by_id
(
0
,
0
,
1
)
assert
len
(
row
)
==
1
assert
len
(
row
[
0
])
==
3
assert
row
[
0
][
'label'
]
==
490
row1
=
reader
.
read_at_page_by_name
(
"image_00007.jpg"
,
0
,
1
)
assert
len
(
row1
)
==
1
assert
len
(
row1
[
0
])
==
3
assert
row1
[
0
][
'label'
]
==
13
paths
=
[
"{}{}"
.
format
(
CV_FILE_NAME
,
str
(
x
).
rjust
(
1
,
'0'
))
for
x
in
range
(
FILES_NUM
)]
for
x
in
paths
:
...
...
tests/ut/python/mindrecord/test_mindrecord_exception.py
浏览文件 @
a9443635
...
...
@@ -15,8 +15,9 @@
"""test mindrecord exception"""
import
os
import
pytest
from
mindspore.mindrecord
import
FileWriter
,
FileReader
,
MindPage
from
mindspore.mindrecord
import
MRMOpenError
,
MRMGenerateIndexError
,
ParamValueError
,
MRMGetMetaError
from
mindspore.mindrecord
import
FileWriter
,
FileReader
,
MindPage
,
SUCCESS
from
mindspore.mindrecord
import
MRMOpenError
,
MRMGenerateIndexError
,
ParamValueError
,
MRMGetMetaError
,
\
MRMFetchDataError
from
mindspore
import
log
as
logger
from
utils
import
get_data
...
...
@@ -286,3 +287,67 @@ def test_add_index_without_add_schema():
fw
=
FileWriter
(
CV_FILE_NAME
)
fw
.
add_index
([
"label"
])
assert
'Failed to get meta info'
in
str
(
err
.
value
)
def test_mindpage_pageno_pagesize_not_int():
    """MindPage rejects non-int page / num_row arguments and fails on an unknown category id."""
    # create_cv_mindrecord is defined elsewhere in this file; presumably it writes
    # the CV_FILE_NAME shard files removed at the end of this test -- confirm.
    create_cv_mindrecord(4)
    reader = MindPage(CV_FILE_NAME + "0")
    fields = reader.get_category_fields()
    assert fields == ['file_name', 'label'], \
        'failed on getting candidate category fields.'

    ret = reader.set_category_field("label")
    assert ret == SUCCESS, 'failed on setting category field.'

    info = reader.read_category_info()
    logger.info("category info: {}".format(info))

    # page passed as str
    with pytest.raises(ParamValueError):
        reader.read_at_page_by_id(0, "0", 1)

    # num_row passed as str
    with pytest.raises(ParamValueError):
        reader.read_at_page_by_id(0, 0, "b")

    # page passed as str (lookup by name)
    with pytest.raises(ParamValueError):
        reader.read_at_page_by_name("822", "e", 1)

    # num_row passed as str (lookup by name)
    with pytest.raises(ParamValueError):
        reader.read_at_page_by_name("822", 0, "qwer")

    # well-typed arguments, but the category id is not expected to exist
    with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."):
        reader.read_at_page_by_id(99999, 0, 1)

    # Remove every shard file and its companion .db file.
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))
def test_mindpage_filename_not_exist():
    """MindPage fails on an unknown category id or file name and rejects a non-str category name."""
    # create_cv_mindrecord is defined elsewhere in this file; presumably it writes
    # the CV_FILE_NAME shard files removed at the end of this test -- confirm.
    create_cv_mindrecord(4)
    reader = MindPage(CV_FILE_NAME + "0")
    fields = reader.get_category_fields()
    assert fields == ['file_name', 'label'], \
        'failed on getting candidate category fields.'

    ret = reader.set_category_field("file_name")
    assert ret == SUCCESS, 'failed on setting category field.'

    info = reader.read_category_info()
    logger.info("category info: {}".format(info))

    # category id that is not expected to exist
    with pytest.raises(MRMFetchDataError):
        reader.read_at_page_by_id(9999, 0, 1)

    # category name (file name) that is not expected to exist
    with pytest.raises(MRMFetchDataError):
        reader.read_at_page_by_name("abc.jpg", 0, 1)

    # category name must be a str
    with pytest.raises(ParamValueError):
        reader.read_at_page_by_name(1, 0, 1)

    # Remove every shard file and its companion .db file.
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录