Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
d4d236bc
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d4d236bc
编写于
5月 06, 2020
作者:
J
jonyguo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix: use MindDataset by column_names get data error in some situation
上级
d004ef22
变更
6
展开全部
隐藏空白更改
内联
并排
Showing
6 changed file
with
1149 addition
and
10 deletion
+1149
-10
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
...e/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
+14
-4
mindspore/ccsrc/mindrecord/include/shard_reader.h
mindspore/ccsrc/mindrecord/include/shard_reader.h
+4
-0
mindspore/ccsrc/mindrecord/io/shard_reader.cc
mindspore/ccsrc/mindrecord/io/shard_reader.cc
+68
-2
mindspore/mindrecord/shardutils.py
mindspore/mindrecord/shardutils.py
+16
-4
tests/ut/python/dataset/test_minddataset.py
tests/ut/python/dataset/test_minddataset.py
+594
-0
tests/ut/python/mindrecord/test_mindrecord_base.py
tests/ut/python/mindrecord/test_mindrecord_base.py
+453
-0
未找到文件。
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
浏览文件 @
d4d236bc
...
...
@@ -165,12 +165,22 @@ Status MindRecordOp::Init() {
Status
MindRecordOp
::
SetColumnsBlob
()
{
columns_blob_
=
shard_reader_
->
get_blob_fields
().
second
;
// get the exactly blob fields by columns_to_load_
std
::
vector
<
std
::
string
>
columns_blob_exact
;
for
(
auto
&
blob_field
:
columns_blob_
)
{
for
(
auto
&
column
:
columns_to_load_
)
{
if
(
column
.
compare
(
blob_field
)
==
0
)
{
columns_blob_exact
.
push_back
(
blob_field
);
break
;
}
}
}
columns_blob_index_
=
std
::
vector
<
int32_t
>
(
columns_to_load_
.
size
(),
-
1
);
int32_t
iBlob
=
0
;
for
(
uint32_t
i
=
0
;
i
<
columns_blob_
.
size
();
++
i
)
{
if
(
column_name_mapping_
.
count
(
columns_blob_
[
i
]))
{
columns_blob_index_
[
column_name_mapping_
[
columns_blob_
[
i
]]]
=
iBlob
++
;
}
for
(
auto
&
blob_exact
:
columns_blob_exact
)
{
columns_blob_index_
[
column_name_mapping_
[
blob_exact
]]
=
iBlob
++
;
}
return
Status
::
OK
();
}
...
...
mindspore/ccsrc/mindrecord/include/shard_reader.h
浏览文件 @
d4d236bc
...
...
@@ -294,6 +294,10 @@ class ShardReader {
/// \brief get number of classes
int64_t
GetNumClasses
(
const
std
::
string
&
file_path
,
const
std
::
string
&
category_field
);
/// \brief get exactly blob fields data by indices
std
::
vector
<
uint8_t
>
ExtractBlobFieldBySelectColumns
(
std
::
vector
<
uint8_t
>
&
blob_fields_bytes
,
std
::
vector
<
uint32_t
>
&
ordered_selected_columns_index
);
protected:
uint64_t
header_size_
;
// header size
uint64_t
page_size_
;
// page size
...
...
mindspore/ccsrc/mindrecord/io/shard_reader.cc
浏览文件 @
d4d236bc
...
...
@@ -790,6 +790,8 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
n_consumer
=
kMinConsumerCount
;
}
CheckNlp
();
// dead code
if
(
nlp_
)
{
selected_columns_
=
selected_columns
;
}
else
{
...
...
@@ -801,6 +803,7 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
}
}
}
selected_columns_
=
selected_columns
;
if
(
CheckColumnList
(
selected_columns_
)
==
FAILED
)
{
MS_LOG
(
ERROR
)
<<
"Illegal column list"
;
...
...
@@ -1060,6 +1063,36 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
return
SUCCESS
;
}
/// \brief Keep only the blob records whose position is listed in ordered_selected_columns_index.
///
/// blob_fields_bytes is walked as a sequence of records, each made of a kInt64Len-byte
/// big-endian length prefix followed by that many payload bytes. A record whose zero-based
/// position appears in ordered_selected_columns_index is copied (length prefix included)
/// to the result; records keep the order they have in blob_fields_bytes.
///
/// \param[in] blob_fields_bytes serialized blob records
/// \param[in] ordered_selected_columns_index zero-based positions of the records to keep
/// \return concatenated bytes of the selected records. Parsing stops at the first record
///         whose prefix or payload does not fit in the remaining bytes, so a truncated
///         buffer yields only the records that were fully present.
std::vector<uint8_t> ShardReader::ExtractBlobFieldBySelectColumns(
  std::vector<uint8_t> &blob_fields_bytes, std::vector<uint32_t> &ordered_selected_columns_index) {
  std::vector<uint8_t> exactly_blob_fields_bytes;
  const uint64_t prefix_len = kInt64Len;
  const uint64_t total_len = blob_fields_bytes.size();

  // Decode the big-endian length prefix that starts at pos.
  // Callers must have verified that pos + prefix_len <= total_len.
  auto uint64_from_bytes = [&blob_fields_bytes, prefix_len](uint64_t pos) {
    uint64_t result = 0;
    for (uint64_t n = 0; n < prefix_len; n++) {
      result = (result << 8) + blob_fields_bytes[pos + n];
    }
    return result;
  };

  // get the exactly blob fields
  uint32_t current_index = 0;
  uint64_t current_offset = 0;
  while (current_offset < total_len) {
    const uint64_t remaining = total_len - current_offset;
    if (remaining < prefix_len) {
      // Not enough bytes left for a length prefix: reading it would run past the buffer.
      break;
    }
    const uint64_t data_len = uint64_from_bytes(current_offset);
    if (data_len > remaining - prefix_len) {
      // Declared payload is larger than what is left; compared this way to avoid
      // overflowing prefix_len + data_len on a corrupt length.
      break;
    }
    const uint64_t record_len = prefix_len + data_len;

    const bool is_selected = std::find(ordered_selected_columns_index.begin(), ordered_selected_columns_index.end(),
                                       current_index) != ordered_selected_columns_index.end();
    if (is_selected) {
      const auto record_begin = blob_fields_bytes.begin() + current_offset;
      exactly_blob_fields_bytes.insert(exactly_blob_fields_bytes.end(), record_begin, record_begin + record_len);
    }

    current_index++;
    current_offset += record_len;
  }

  return exactly_blob_fields_bytes;
}
TASK_RETURN_CONTENT
ShardReader
::
ConsumerOneTask
(
int
task_id
,
uint32_t
consumer_id
)
{
// All tasks are done
if
(
task_id
>=
static_cast
<
int
>
(
tasks_
.
Size
()))
{
...
...
@@ -1077,6 +1110,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
return
std
::
make_pair
(
FAILED
,
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
());
}
const
std
::
shared_ptr
<
Page
>
&
page
=
ret
.
second
;
// Pack image list
std
::
vector
<
uint8_t
>
images
(
addr
[
1
]
-
addr
[
0
]);
auto
file_offset
=
header_size_
+
page_size_
*
(
page
->
get_page_id
())
+
addr
[
0
];
...
...
@@ -1096,10 +1130,42 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
return
std
::
make_pair
(
FAILED
,
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
());
}
// extract the exactly blob bytes by selected columns
std
::
vector
<
uint8_t
>
images_with_exact_columns
;
if
(
selected_columns_
.
size
()
==
0
)
{
images_with_exact_columns
=
images
;
}
else
{
auto
blob_fields
=
get_blob_fields
();
std
::
vector
<
uint32_t
>
ordered_selected_columns_index
;
uint32_t
index
=
0
;
for
(
auto
&
blob_field
:
blob_fields
.
second
)
{
for
(
auto
&
field
:
selected_columns_
)
{
if
(
field
.
compare
(
blob_field
)
==
0
)
{
ordered_selected_columns_index
.
push_back
(
index
);
break
;
}
}
index
++
;
}
if
(
ordered_selected_columns_index
.
size
()
!=
0
)
{
// extract the images
if
(
blob_fields
.
second
.
size
()
==
1
)
{
if
(
ordered_selected_columns_index
.
size
()
==
1
)
{
images_with_exact_columns
=
images
;
}
}
else
{
images_with_exact_columns
=
ExtractBlobFieldBySelectColumns
(
images
,
ordered_selected_columns_index
);
}
}
}
// Deliver batch data to output map
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
batch
;
if
(
nlp_
)
{
json
blob_fields
=
json
::
from_msgpack
(
images
);
// dead code
json
blob_fields
=
json
::
from_msgpack
(
images_with_exact_columns
);
json
merge
;
if
(
selected_columns_
.
size
()
>
0
)
{
...
...
@@ -1117,7 +1183,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
}
batch
.
emplace_back
(
std
::
vector
<
uint8_t
>
{},
std
::
move
(
merge
));
}
else
{
batch
.
emplace_back
(
std
::
move
(
images
),
std
::
move
(
std
::
get
<
2
>
(
task
)));
batch
.
emplace_back
(
std
::
move
(
images
_with_exact_columns
),
std
::
move
(
std
::
get
<
2
>
(
task
)));
}
return
std
::
make_pair
(
SUCCESS
,
std
::
move
(
batch
));
}
...
...
mindspore/mindrecord/shardutils.py
浏览文件 @
d4d236bc
...
...
@@ -92,15 +92,25 @@ def populate_data(raw, blob, columns, blob_fields, schema):
if
raw
:
# remove dummy fields
raw
=
{
k
:
v
for
k
,
v
in
raw
.
items
()
if
k
in
schema
}
else
:
raw
=
{}
if
not
blob_fields
:
return
raw
# Get the order preserving sequence of columns in blob
ordered_columns
=
[]
if
columns
:
for
blob_field
in
blob_fields
:
if
blob_field
in
columns
:
ordered_columns
.
append
(
blob_field
)
else
:
ordered_columns
=
blob_fields
blob_bytes
=
bytes
(
blob
)
def
_render_raw
(
field
,
blob_data
):
data_type
=
schema
[
field
][
'type'
]
data_shape
=
schema
[
field
][
'shape'
]
if
'shape'
in
schema
[
field
]
else
[]
if
columns
and
field
not
in
columns
:
return
if
data_shape
:
try
:
raw
[
field
]
=
np
.
reshape
(
np
.
frombuffer
(
blob_data
,
dtype
=
data_type
),
data_shape
)
...
...
@@ -110,7 +120,9 @@ def populate_data(raw, blob, columns, blob_fields, schema):
raw
[
field
]
=
blob_data
if
len
(
blob_fields
)
==
1
:
_render_raw
(
blob_fields
[
0
],
blob_bytes
)
if
len
(
ordered_columns
)
==
1
:
_render_raw
(
blob_fields
[
0
],
blob_bytes
)
return
raw
return
raw
def
_int_from_bytes
(
xbytes
:
bytes
)
->
int
:
...
...
@@ -125,6 +137,6 @@ def populate_data(raw, blob, columns, blob_fields, schema):
start
+=
8
return
blob_bytes
[
start
:
start
+
n_bytes
]
for
i
,
blob_field
in
enumerate
(
blob_field
s
):
for
i
,
blob_field
in
enumerate
(
ordered_column
s
):
_render_raw
(
blob_field
,
_blob_at_position
(
i
))
return
raw
tests/ut/python/dataset/test_minddataset.py
浏览文件 @
d4d236bc
此差异已折叠。
点击以展开。
tests/ut/python/mindrecord/test_mindrecord_base.py
浏览文件 @
d4d236bc
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录