magicwindyyd / mindspore (forked from MindSpore / mindspore; kept in sync with the fork source)
Commit cf352d19
Authored on May 08, 2020 by jonyguo

format func name for mindrecord

Parent: 93429aba
Showing 35 changed files with 308 additions and 308 deletions (+308, -308).
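The change itself is mechanical: MindRecord accessor and operator methods are renamed from snake_case (get_num_rows, get_schemas, set_page_size, execute, ...) to CamelCase (GetNumRows, GetSchemas, SetPageSize, Execute, ...), and every call site is updated to match; the names exposed to Python through the pybind bindings keep their old snake_case form. A minimal sketch of the pattern follows; ShardFoo is a hypothetical class used only for illustration, the real classes touched are ShardHeader, ShardReader, ShardWriter, Page, and the ShardOperator hierarchy.

#include <cstdint>

// Sketch of the rename applied throughout this commit (hypothetical ShardFoo,
// not part of the diff; the pattern is accessor names moving to CamelCase).
class ShardFoo {
 public:
  // before: int get_num_rows() const { return num_rows_; }
  int GetNumRows() const { return num_rows_; }

  // before: void set_page_size(const uint64_t &page_size) { page_size_ = page_size; }
  void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }

 private:
  int num_rows_ = 0;
  uint64_t page_size_ = 0;
};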
Changed files:

mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc (+4, -4)
mindspore/ccsrc/mindrecord/common/shard_pybind.cc (+14, -14)
mindspore/ccsrc/mindrecord/include/shard_category.h (+2, -2)
mindspore/ccsrc/mindrecord/include/shard_header.h (+15, -15)
mindspore/ccsrc/mindrecord/include/shard_index.h (+1, -1)
mindspore/ccsrc/mindrecord/include/shard_operator.h (+6, -6)
mindspore/ccsrc/mindrecord/include/shard_page.h (+12, -12)
mindspore/ccsrc/mindrecord/include/shard_pk_sample.h (+1, -1)
mindspore/ccsrc/mindrecord/include/shard_reader.h (+6, -6)
mindspore/ccsrc/mindrecord/include/shard_sample.h (+3, -3)
mindspore/ccsrc/mindrecord/include/shard_schema.h (+4, -4)
mindspore/ccsrc/mindrecord/include/shard_segment.h (+1, -1)
mindspore/ccsrc/mindrecord/include/shard_shuffle.h (+1, -1)
mindspore/ccsrc/mindrecord/include/shard_statistics.h (+4, -4)
mindspore/ccsrc/mindrecord/include/shard_task.h (+2, -2)
mindspore/ccsrc/mindrecord/include/shard_writer.h (+2, -2)
mindspore/ccsrc/mindrecord/io/shard_index_generator.cc (+22, -22)
mindspore/ccsrc/mindrecord/io/shard_reader.cc (+40, -40)
mindspore/ccsrc/mindrecord/io/shard_segment.cc (+8, -8)
mindspore/ccsrc/mindrecord/io/shard_writer.cc (+37, -37)
mindspore/ccsrc/mindrecord/meta/shard_category.cc (+1, -1)
mindspore/ccsrc/mindrecord/meta/shard_header.cc (+25, -25)
mindspore/ccsrc/mindrecord/meta/shard_index.cc (+1, -1)
mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc (+1, -1)
mindspore/ccsrc/mindrecord/meta/shard_sample.cc (+6, -6)
mindspore/ccsrc/mindrecord/meta/shard_schema.cc (+5, -5)
mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc (+1, -1)
mindspore/ccsrc/mindrecord/meta/shard_statistics.cc (+6, -6)
mindspore/ccsrc/mindrecord/meta/shard_task.cc (+4, -4)
tests/ut/cpp/mindrecord/ut_shard.cc (+6, -6)
tests/ut/cpp/mindrecord/ut_shard_header_test.cc (+11, -11)
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc (+1, -1)
tests/ut/cpp/mindrecord/ut_shard_page_test.cc (+39, -39)
tests/ut/cpp/mindrecord/ut_shard_schema_test.cc (+6, -6)
tests/ut/cpp/mindrecord/ut_shard_writer_test.cc (+10, -10)
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc

@@ -108,7 +108,7 @@ Status MindRecordOp::Init() {
   data_schema_ = std::make_unique<DataSchema>();
-  std::vector<std::shared_ptr<Schema>> schema_vec = shard_reader_->get_shard_header()->get_schemas();
+  std::vector<std::shared_ptr<Schema>> schema_vec = shard_reader_->GetShardHeader()->GetSchemas();
   // check whether schema exists, if so use the first one
   CHECK_FAIL_RETURN_UNEXPECTED(!schema_vec.empty(), "No schema found");
   mindrecord::json mr_schema = schema_vec[0]->GetSchema()["schema"];

@@ -155,7 +155,7 @@ Status MindRecordOp::Init() {
     column_name_mapping_[columns_to_load_[i]] = i;
   }
-  num_rows_ = shard_reader_->get_num_rows();
+  num_rows_ = shard_reader_->GetNumRows();
   // Compute how many buffers we would need to accomplish rowsPerBuffer
   buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_;
   RETURN_IF_NOT_OK(SetColumnsBlob());

@@ -164,7 +164,7 @@ Status MindRecordOp::Init() {
 }

 Status MindRecordOp::SetColumnsBlob() {
-  columns_blob_ = shard_reader_->get_blob_fields().second;
+  columns_blob_ = shard_reader_->GetBlobFields().second;
   // get the exactly blob fields by columns_to_load_
   std::vector<std::string> columns_blob_exact;

@@ -600,7 +600,7 @@ Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) {
 // Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work
 Status MindRecordOp::operator()() {
   RETURN_IF_NOT_OK(LaunchThreadAndInitOp());
-  num_rows_ = shard_reader_->get_num_rows();
+  num_rows_ = shard_reader_->GetNumRows();
   buffers_needed_ = num_rows_ / rows_per_buffer_;
   if (num_rows_ % rows_per_buffer_ != 0) {
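The buffers_needed_ computation in the first hunk above is a ceiling division: the row count divided by rows-per-buffer, rounded up. A standalone restatement of that arithmetic (illustrative helper, not part of the diff):

#include <cstdint>

// Ceiling division as used for buffers_needed_ in MindRecordOp::Init():
// (num_rows + rows_per_buffer - 1) / rows_per_buffer rounds up, e.g. 10 rows
// with 3 rows per buffer gives 4 buffers.
int64_t BuffersNeeded(int64_t num_rows, int64_t rows_per_buffer) {
  return (num_rows + rows_per_buffer - 1) / rows_per_buffer;
}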
mindspore/ccsrc/mindrecord/common/shard_pybind.cc

@@ -39,18 +39,18 @@ namespace mindrecord {
 void BindSchema(py::module *m) {
   (void)py::class_<Schema, std::shared_ptr<Schema>>(*m, "Schema", py::module_local())
     .def_static("build", (std::shared_ptr<Schema> (*)(std::string, py::handle)) & Schema::Build)
-    .def("get_desc", &Schema::get_desc)
+    .def("get_desc", &Schema::GetDesc)
     .def("get_schema_content", (py::object (Schema::*)()) & Schema::GetSchemaForPython)
-    .def("get_blob_fields", &Schema::get_blob_fields)
-    .def("get_schema_id", &Schema::get_schema_id);
+    .def("get_blob_fields", &Schema::GetBlobFields)
+    .def("get_schema_id", &Schema::GetSchemaID);
 }

 void BindStatistics(const py::module *m) {
   (void)py::class_<Statistics, std::shared_ptr<Statistics>>(*m, "Statistics", py::module_local())
     .def_static("build", (std::shared_ptr<Statistics> (*)(std::string, py::handle)) & Statistics::Build)
-    .def("get_desc", &Statistics::get_desc)
+    .def("get_desc", &Statistics::GetDesc)
     .def("get_statistics", (py::object (Statistics::*)()) & Statistics::GetStatisticsForPython)
-    .def("get_statistics_id", &Statistics::get_statistics_id);
+    .def("get_statistics_id", &Statistics::GetStatisticsID);
 }

 void BindShardHeader(const py::module *m) {

@@ -60,9 +60,9 @@ void BindShardHeader(const py::module *m) {
     .def("add_statistics", &ShardHeader::AddStatistic)
     .def("add_index_fields", (MSRStatus (ShardHeader::*)(const std::vector<std::string> &)) & ShardHeader::AddIndexFields)
-    .def("get_meta", &ShardHeader::get_schemas)
-    .def("get_statistics", &ShardHeader::get_statistics)
-    .def("get_fields", &ShardHeader::get_fields)
+    .def("get_meta", &ShardHeader::GetSchemas)
+    .def("get_statistics", &ShardHeader::GetStatistics)
+    .def("get_fields", &ShardHeader::GetFields)
     .def("get_schema_by_id", &ShardHeader::GetSchemaByID)
     .def("get_statistic_by_id", &ShardHeader::GetStatisticByID);
 }

@@ -72,8 +72,8 @@ void BindShardWriter(py::module *m) {
     .def(py::init<>())
     .def("open", &ShardWriter::Open)
     .def("open_for_append", &ShardWriter::OpenForAppend)
-    .def("set_header_size", &ShardWriter::set_header_size)
-    .def("set_page_size", &ShardWriter::set_page_size)
+    .def("set_header_size", &ShardWriter::SetHeaderSize)
+    .def("set_page_size", &ShardWriter::SetPageSize)
     .def("set_shard_header", &ShardWriter::SetShardHeader)
     .def("write_raw_data", (MSRStatus (ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &, vector<vector<uint8_t>> &, bool, bool)) &

@@ -88,8 +88,8 @@ void BindShardReader(const py::module *m) {
                        const std::vector<std::shared_ptr<ShardOperator>> &)) & ShardReader::OpenPy)
     .def("launch", &ShardReader::Launch)
-    .def("get_header", &ShardReader::get_shard_header)
-    .def("get_blob_fields", &ShardReader::get_blob_fields)
+    .def("get_header", &ShardReader::GetShardHeader)
+    .def("get_blob_fields", &ShardReader::GetBlobFields)
     .def("get_next",
          (std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>> (ShardReader::*)()) & ShardReader::GetNextPy)
     .def("finish", &ShardReader::Finish)

@@ -119,9 +119,9 @@ void BindShardSegment(py::module *m) {
     .def("read_at_page_by_name",
          (std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> (ShardSegment::*)(
            std::string, int64_t, int64_t)) & ShardSegment::ReadAtPageByNamePy)
-    .def("get_header", &ShardSegment::get_shard_header)
+    .def("get_header", &ShardSegment::GetShardHeader)
     .def("get_blob_fields",
-         (std::pair<ShardType, std::vector<std::string>> (ShardSegment::*)()) & ShardSegment::get_blob_fields);
+         (std::pair<ShardType, std::vector<std::string>> (ShardSegment::*)()) & ShardSegment::GetBlobFields);
 }

 void BindGlobalParams(py::module *m) {
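Worth noting: the quoted names registered with pybind11 ("get_desc", "get_header", "get_blob_fields", and so on) are unchanged, so the Python side of MindRecord keeps its existing API; only the C++ member-function pointers bound to those names are renamed. A minimal self-contained sketch of that binding pattern, using a hypothetical Widget class rather than the real MindRecord types:

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical stand-in for ShardReader/ShardWriter, for illustration only.
class Widget {
 public:
  int GetNumRows() const { return 42; }  // renamed C++ method (was get_num_rows)
};

PYBIND11_MODULE(widget_demo, m) {
  // The Python-visible name stays snake_case; only the C++ target changes.
  py::class_<Widget>(m, "Widget")
      .def(py::init<>())
      .def("get_num_rows", &Widget::GetNumRows);
}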
mindspore/ccsrc/mindrecord/include/shard_category.h

@@ -36,7 +36,7 @@ class ShardCategory : public ShardOperator {
   ~ShardCategory() override {};

-  const std::vector<std::pair<std::string, std::string>> &get_categories() const { return categories_; }
+  const std::vector<std::pair<std::string, std::string>> &GetCategories() const { return categories_; }

   const std::string GetCategoryField() const { return category_field_; }

@@ -46,7 +46,7 @@ class ShardCategory : public ShardOperator {
   bool GetReplacement() const { return replacement_; }

-  MSRStatus execute(ShardTask &tasks) override;
+  MSRStatus Execute(ShardTask &tasks) override;

   int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
mindspore/ccsrc/mindrecord/include/shard_header.h

@@ -58,19 +58,19 @@ class ShardHeader {
   /// \brief get the schema
   /// \return the schema
-  std::vector<std::shared_ptr<Schema>> get_schemas();
+  std::vector<std::shared_ptr<Schema>> GetSchemas();

   /// \brief get Statistics
   /// \return the Statistic
-  std::vector<std::shared_ptr<Statistics>> get_statistics();
+  std::vector<std::shared_ptr<Statistics>> GetStatistics();

   /// \brief get the fields of the index
   /// \return the fields of the index
-  std::vector<std::pair<uint64_t, std::string>> get_fields();
+  std::vector<std::pair<uint64_t, std::string>> GetFields();

   /// \brief get the index
   /// \return the index
-  std::shared_ptr<Index> get_index();
+  std::shared_ptr<Index> GetIndex();

   /// \brief get the schema by schemaid
   /// \param[in] schemaId the id of schema needs to be got

@@ -80,7 +80,7 @@ class ShardHeader {
   /// \brief get the filepath to shard by shardID
   /// \param[in] shardID the id of shard which filepath needs to be obtained
   /// \return the filepath obtained by shardID
-  std::string get_shard_address_by_id(int64_t shard_id);
+  std::string GetShardAddressByID(int64_t shard_id);

   /// \brief get the statistic by statistic id
   /// \param[in] statisticId the id of statistic needs to be get

@@ -89,7 +89,7 @@ class ShardHeader {
   MSRStatus InitByFiles(const std::vector<std::string> &file_paths);

-  void set_index(Index index) { index_ = std::make_shared<Index>(index); }
+  void SetIndex(Index index) { index_ = std::make_shared<Index>(index); }

   std::pair<std::shared_ptr<Page>, MSRStatus> GetPage(const int &shard_id, const int &page_id);

@@ -103,21 +103,21 @@ class ShardHeader {
   const std::pair<MSRStatus, std::shared_ptr<Page>> GetPageByGroupId(const int &group_id, const int &shard_id);

-  std::vector<std::string> get_shard_addresses() const { return shard_addresses_; }
+  std::vector<std::string> GetShardAddresses() const { return shard_addresses_; }

-  int get_shard_count() const { return shard_count_; }
+  int GetShardCount() const { return shard_count_; }

-  int get_schema_count() const { return schema_.size(); }
+  int GetSchemaCount() const { return schema_.size(); }

-  uint64_t get_header_size() const { return header_size_; }
+  uint64_t GetHeaderSize() const { return header_size_; }

-  uint64_t get_page_size() const { return page_size_; }
+  uint64_t GetPageSize() const { return page_size_; }

-  void set_header_size(const uint64_t &header_size) { header_size_ = header_size; }
+  void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; }

-  void set_page_size(const uint64_t &page_size) { page_size_ = page_size; }
+  void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }

-  const string get_version() { return version_; }
+  const string GetVersion() { return version_; }

   std::vector<std::string> SerializeHeader();

@@ -132,7 +132,7 @@ class ShardHeader {
   /// \param[in] the shard data real path
   /// \param[in] the headers which readed from the shard data
   /// \return SUCCESS/FAILED
-  MSRStatus get_headers(const vector<string> &real_addresses, std::vector<json> &headers);
+  MSRStatus GetHeaders(const vector<string> &real_addresses, std::vector<json> &headers);

   MSRStatus ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id);
mindspore/ccsrc/mindrecord/include/shard_index.h

@@ -52,7 +52,7 @@ class Index {
   /// \brief get stored fields
   /// \return fields stored
-  std::vector<std::pair<uint64_t, std::string> > get_fields();
+  std::vector<std::pair<uint64_t, std::string> > GetFields();

  private:
   std::vector<std::pair<uint64_t, std::string> > fields_;
mindspore/ccsrc/mindrecord/include/shard_operator.h

@@ -26,23 +26,23 @@ class ShardOperator {
   virtual ~ShardOperator() = default;

   MSRStatus operator()(ShardTask &tasks) {
-    if (SUCCESS != this->pre_execute(tasks)) {
+    if (SUCCESS != this->PreExecute(tasks)) {
       return FAILED;
     }
-    if (SUCCESS != this->execute(tasks)) {
+    if (SUCCESS != this->Execute(tasks)) {
       return FAILED;
     }
-    if (SUCCESS != this->suf_execute(tasks)) {
+    if (SUCCESS != this->SufExecute(tasks)) {
       return FAILED;
     }
     return SUCCESS;
   }

-  virtual MSRStatus pre_execute(ShardTask &tasks) { return SUCCESS; }
+  virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; }

-  virtual MSRStatus execute(ShardTask &tasks) = 0;
+  virtual MSRStatus Execute(ShardTask &tasks) = 0;

-  virtual MSRStatus suf_execute(ShardTask &tasks) { return SUCCESS; }
+  virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; }

   virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; }
 };
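ShardOperator::operator() (above) chains PreExecute, Execute, and SufExecute over a ShardTask and fails fast on the first error; concrete operators such as ShardSample, ShardShuffle, and ShardCategory now override the CamelCase hooks. A self-contained sketch of the same dispatch pattern, with simplified stand-in types (MyStatus and TaskList are illustrative, not the MindRecord types):

#include <algorithm>
#include <vector>

// Simplified stand-ins for MSRStatus and ShardTask, for illustration only.
enum MyStatus { MY_SUCCESS, MY_FAILED };
using TaskList = std::vector<int>;

class OperatorBase {
 public:
  virtual ~OperatorBase() = default;

  // Mirrors ShardOperator::operator(): pre -> main -> post, failing fast.
  MyStatus operator()(TaskList &tasks) {
    if (PreExecute(tasks) != MY_SUCCESS) return MY_FAILED;
    if (Execute(tasks) != MY_SUCCESS) return MY_FAILED;
    if (SufExecute(tasks) != MY_SUCCESS) return MY_FAILED;
    return MY_SUCCESS;
  }

  virtual MyStatus PreExecute(TaskList &tasks) { return MY_SUCCESS; }
  virtual MyStatus Execute(TaskList &tasks) = 0;
  virtual MyStatus SufExecute(TaskList &tasks) { return MY_SUCCESS; }
};

// A concrete operator only needs to override Execute (plus the optional hooks).
class ReverseOperator : public OperatorBase {
 public:
  MyStatus Execute(TaskList &tasks) override {
    std::reverse(tasks.begin(), tasks.end());
    return MY_SUCCESS;
  }
};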
mindspore/ccsrc/mindrecord/include/shard_page.h

@@ -53,29 +53,29 @@ class Page {
   /// \return the json format of the page and its description
   json GetPage() const;

-  int get_page_id() const { return page_id_; }
+  int GetPageID() const { return page_id_; }

-  int get_shard_id() const { return shard_id_; }
+  int GetShardID() const { return shard_id_; }

-  int get_page_type_id() const { return page_type_id_; }
+  int GetPageTypeID() const { return page_type_id_; }

-  std::string get_page_type() const { return page_type_; }
+  std::string GetPageType() const { return page_type_; }

-  uint64_t get_page_size() const { return page_size_; }
+  uint64_t GetPageSize() const { return page_size_; }

-  uint64_t get_start_row_id() const { return start_row_id_; }
+  uint64_t GetStartRowID() const { return start_row_id_; }

-  uint64_t get_end_row_id() const { return end_row_id_; }
+  uint64_t GetEndRowID() const { return end_row_id_; }

-  void set_end_row_id(const uint64_t &end_row_id) { end_row_id_ = end_row_id; }
+  void SetEndRowID(const uint64_t &end_row_id) { end_row_id_ = end_row_id; }

-  void set_page_size(const uint64_t &page_size) { page_size_ = page_size; }
+  void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }

-  std::pair<int, uint64_t> get_last_row_group_id() const { return row_group_ids_.back(); }
+  std::pair<int, uint64_t> GetLastRowGroupID() const { return row_group_ids_.back(); }

-  std::vector<std::pair<int, uint64_t>> get_row_group_ids() const { return row_group_ids_; }
+  std::vector<std::pair<int, uint64_t>> GetRowGroupIds() const { return row_group_ids_; }

-  void set_row_group_ids(const std::vector<std::pair<int, uint64_t>> &last_row_group_ids) {
+  void SetRowGroupIds(const std::vector<std::pair<int, uint64_t>> &last_row_group_ids) {
     row_group_ids_ = last_row_group_ids;
   }
mindspore/ccsrc/mindrecord/include/shard_pk_sample.h

@@ -37,7 +37,7 @@ class ShardPkSample : public ShardCategory {
   ~ShardPkSample() override {};

-  MSRStatus suf_execute(ShardTask &tasks) override;
+  MSRStatus SufExecute(ShardTask &tasks) override;

  private:
   bool shuffle_;
mindspore/ccsrc/mindrecord/include/shard_reader.h

@@ -107,11 +107,11 @@ class ShardReader {
   /// \brief aim to get the meta data
   /// \return the metadata
-  std::shared_ptr<ShardHeader> get_shard_header() const;
+  std::shared_ptr<ShardHeader> GetShardHeader() const;

   /// \brief get the number of shards
   /// \return # of shards
-  int get_shard_count() const;
+  int GetShardCount() const;

   /// \brief get the number of rows in database
   /// \param[in] file_path the path of ONE file, any file in dataset is fine

@@ -126,7 +126,7 @@ class ShardReader {
   /// \brief get the number of rows in database
   /// \return # of rows
-  int get_num_rows() const;
+  int GetNumRows() const;

   /// \brief Read the summary of row groups
   /// \return the tuple of 4 elements

@@ -185,7 +185,7 @@ class ShardReader {
   /// \brief get blob filed list
   /// \return blob field list
-  std::pair<ShardType, std::vector<std::string>> get_blob_fields();
+  std::pair<ShardType, std::vector<std::string>> GetBlobFields();

   /// \brief reset reader
   /// \return null

@@ -193,10 +193,10 @@ class ShardReader {
   /// \brief set flag of all-in-index
   /// \return null
-  void set_all_in_index(bool all_in_index) { all_in_index_ = all_in_index; }
+  void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; }

   /// \brief get NLP flag
-  bool get_nlp_flag();
+  bool GetNlpFlag();

   /// \brief get all classes
   MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories);
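For orientation, a hedged usage sketch of the renamed reader accessors follows. It is not part of the diff: the include path is a guess at this tree's layout, and it assumes a ShardReader that has already been opened or initialised.

#include "mindrecord/include/shard_reader.h"  // assumed include path

void InspectDataset(mindspore::mindrecord::ShardReader &shard_reader) {
  // Header-level metadata through the renamed accessors.
  auto header = shard_reader.GetShardHeader();             // was get_shard_header()
  auto schemas = header->GetSchemas();                     // was get_schemas()
  auto index_fields = header->GetFields();                 // was get_fields()

  // Row and blob-field information on the reader itself.
  int num_rows = shard_reader.GetNumRows();                // was get_num_rows()
  auto blob_fields = shard_reader.GetBlobFields().second;  // was get_blob_fields()

  (void)schemas;
  (void)index_fields;
  (void)num_rows;
  (void)blob_fields;
}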
mindspore/ccsrc/mindrecord/include/shard_sample.h

@@ -38,11 +38,11 @@ class ShardSample : public ShardOperator {
   ~ShardSample() override {};

-  const std::pair<int, int> get_partitions() const;
+  const std::pair<int, int> GetPartitions() const;

-  MSRStatus execute(ShardTask &tasks) override;
+  MSRStatus Execute(ShardTask &tasks) override;

-  MSRStatus suf_execute(ShardTask &tasks) override;
+  MSRStatus SufExecute(ShardTask &tasks) override;

   int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
mindspore/ccsrc/mindrecord/include/shard_schema.h

@@ -51,7 +51,7 @@ class Schema {
   /// \brief get the schema and its description
   /// \return the json format of the schema and its description
-  std::string get_desc() const;
+  std::string GetDesc() const;

   /// \brief get the schema and its description
   /// \return the json format of the schema and its description

@@ -63,15 +63,15 @@ class Schema {
   /// set the schema id
   /// \param[in] id the id need to be set
-  void set_schema_id(int64_t id);
+  void SetSchemaID(int64_t id);

   /// get the schema id
   /// \return the int64 schema id
-  int64_t get_schema_id() const;
+  int64_t GetSchemaID() const;

   /// get the blob fields
   /// \return the vector<string> blob fields
-  std::vector<std::string> get_blob_fields() const;
+  std::vector<std::string> GetBlobFields() const;

  private:
   Schema() = default;
mindspore/ccsrc/mindrecord/include/shard_segment.h

@@ -81,7 +81,7 @@ class ShardSegment : public ShardReader {
   std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ReadAtPageByNamePy(
     std::string category_name, int64_t page_no, int64_t n_rows_of_page);

-  std::pair<ShardType, std::vector<std::string>> get_blob_fields();
+  std::pair<ShardType, std::vector<std::string>> GetBlobFields();

  private:
   std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> WrapCategoryInfo();
mindspore/ccsrc/mindrecord/include/shard_shuffle.h

@@ -28,7 +28,7 @@ class ShardShuffle : public ShardOperator {
   ~ShardShuffle() override {};

-  MSRStatus execute(ShardTask &tasks) override;
+  MSRStatus Execute(ShardTask &tasks) override;

  private:
   uint32_t shuffle_seed_;
mindspore/ccsrc/mindrecord/include/shard_statistics.h

@@ -53,11 +53,11 @@ class Statistics {
   /// \brief get the description
   /// \return the description
-  std::string get_desc() const;
+  std::string GetDesc() const;

   /// \brief get the statistic
   /// \return json format of the statistic
-  json get_statistics() const;
+  json GetStatistics() const;

   /// \brief get the statistic for python
   /// \return the python object of statistics

@@ -66,11 +66,11 @@ class Statistics {
   /// \brief decode the bson statistics to json
   /// \param[in] encodedStatistics the bson type of statistics
   /// \return json type of statistic
-  void set_statistics_id(int64_t id);
+  void SetStatisticsID(int64_t id);

   /// \brief get the statistics id
   /// \return the int64 statistics id
-  int64_t get_statistics_id() const;
+  int64_t GetStatisticsID() const;

  private:
   /// \brief validate the statistic
mindspore/ccsrc/mindrecord/include/shard_task.h

@@ -39,9 +39,9 @@ class ShardTask {
   uint32_t SizeOfRows() const;

-  std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &get_task_by_id(size_t id);
+  std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &GetTaskByID(size_t id);

-  std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &get_random_task();
+  std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &GetRandomTask();

   static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements);
mindspore/ccsrc/mindrecord/include/shard_writer.h

@@ -69,12 +69,12 @@ class ShardWriter {
   /// \brief Set file size
   /// \param[in] header_size the size of header, only (1<<N) is accepted
   /// \return MSRStatus the status of MSRStatus
-  MSRStatus set_header_size(const uint64_t &header_size);
+  MSRStatus SetHeaderSize(const uint64_t &header_size);

   /// \brief Set page size
   /// \param[in] page_size the size of page, only (1<<N) is accepted
   /// \return MSRStatus the status of MSRStatus
-  MSRStatus set_page_size(const uint64_t &page_size);
+  MSRStatus SetPageSize(const uint64_t &page_size);

   /// \brief Set shard header
   /// \param[in] header_data the info of header
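The writer's size setters keep their validation (per shard_writer.cc further down, header size is checked against roughly 16 KB to 128 MB and page size against 32 KB to 256 MB); only the names change. A hedged configuration sketch, not from the diff; the include path and the concrete sizes are assumptions:

#include "mindrecord/include/shard_writer.h"  // assumed include path

// Configure a writer through the renamed setters; 16 MB / 32 MB are
// illustrative values inside the documented ranges.
void ConfigureWriter(mindspore::mindrecord::ShardWriter &writer) {
  if (writer.SetHeaderSize(1 << 24) == mindspore::mindrecord::FAILED) {  // was set_header_size()
    return;
  }
  if (writer.SetPageSize(1 << 25) == mindspore::mindrecord::FAILED) {  // was set_page_size()
    return;
  }
}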
mindspore/ccsrc/mindrecord/io/shard_index_generator.cc

@@ -64,7 +64,7 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const str
   }

   // schema does not contain the field
-  auto schema = shard_header_.get_schemas()[0]->GetSchema()["schema"];
+  auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"];
   if (schema.find(field) == schema.end()) {
     MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema;
     return {FAILED, ""};

@@ -203,7 +203,7 @@ MSRStatus ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::stri
 }

 std::pair<MSRStatus, sqlite3 *> ShardIndexGenerator::CreateDatabase(int shard_no) {
-  std::string shard_address = shard_header_.get_shard_address_by_id(shard_no);
+  std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
   if (shard_address.empty()) {
     MS_LOG(ERROR) << "Shard address is null, shard no: " << shard_no;
     return {FAILED, nullptr};

@@ -357,12 +357,12 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
 MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data,
                                                const std::shared_ptr<Page> cur_blob_page,
                                                uint64_t &cur_blob_page_offset, std::fstream &in) {
-  row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->get_page_id()));
+  row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID()));

   // blob data start
   row_data.emplace_back(":PAGE_OFFSET_BLOB", "INTEGER", std::to_string(cur_blob_page_offset));
   auto &io_seekg_blob =
-    in.seekg(page_size_ * cur_blob_page->get_page_id() + header_size_ + cur_blob_page_offset, std::ios::beg);
+    in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg);
   if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) {
     MS_LOG(ERROR) << "File seekg failed";
     in.close();

@@ -405,7 +405,7 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
     std::shared_ptr<Page> cur_raw_page = shard_header_.GetPage(shard_no, raw_page_id).first;

     // related blob page
-    vector<pair<int, uint64_t>> row_group_list = cur_raw_page->get_row_group_ids();
+    vector<pair<int, uint64_t>> row_group_list = cur_raw_page->GetRowGroupIds();

     // pair: row_group id, offset in raw data page
     for (pair<int, int> blob_ids : row_group_list) {

@@ -415,18 +415,18 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
       // offset in current raw data page
       auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second);
       uint64_t cur_blob_page_offset = 0;
-      for (unsigned int i = cur_blob_page->get_start_row_id(); i < cur_blob_page->get_end_row_id(); ++i) {
+      for (unsigned int i = cur_blob_page->GetStartRowID(); i < cur_blob_page->GetEndRowID(); ++i) {
         std::vector<std::tuple<std::string, std::string, std::string>> row_data;
         row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i));
-        row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->get_page_type_id()));
-        row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->get_page_id()));
+        row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->GetPageTypeID()));
+        row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->GetPageID()));

         // raw data start
         row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset));

         // calculate raw data end
         auto &io_seekg =
-          in.seekg(page_size_ * (cur_raw_page->get_page_id()) + header_size_ + cur_raw_page_offset, std::ios::beg);
+          in.seekg(page_size_ * (cur_raw_page->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg);
         if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
           MS_LOG(ERROR) << "File seekg failed";
           in.close();

@@ -473,7 +473,7 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
 INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_detail) {
   std::vector<std::tuple<std::string, std::string, std::string>> fields;
   // index fields
-  std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.get_fields();
+  std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.GetFields();
   for (const auto &field : index_fields) {
     if (field.first >= schema_detail.size()) {
       return {FAILED, {}};

@@ -504,7 +504,7 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std
                                                   const std::vector<int> &raw_page_ids, const std::map<int, int> &blob_id_to_page_id) {
   // Add index data to database
-  std::string shard_address = shard_header_.get_shard_address_by_id(shard_no);
+  std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
   if (shard_address.empty()) {
     MS_LOG(ERROR) << "Shard address is null";
     return FAILED;

@@ -546,12 +546,12 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std
 }

 MSRStatus ShardIndexGenerator::WriteToDatabase() {
-  fields_ = shard_header_.get_fields();
-  page_size_ = shard_header_.get_page_size();
-  header_size_ = shard_header_.get_header_size();
-  schema_count_ = shard_header_.get_schema_count();
-  if (shard_header_.get_shard_count() > kMaxShardCount) {
-    MS_LOG(ERROR) << "num shards: " << shard_header_.get_shard_count() << " exceeds max count:" << kMaxSchemaCount;
+  fields_ = shard_header_.GetFields();
+  page_size_ = shard_header_.GetPageSize();
+  header_size_ = shard_header_.GetHeaderSize();
+  schema_count_ = shard_header_.GetSchemaCount();
+  if (shard_header_.GetShardCount() > kMaxShardCount) {
+    MS_LOG(ERROR) << "num shards: " << shard_header_.GetShardCount() << " exceeds max count:" << kMaxSchemaCount;
     return FAILED;
   }
   task_ = 0;  // set two atomic vars to initial value

@@ -559,7 +559,7 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() {
   // spawn half the physical threads or total number of shards whichever is smaller
   const unsigned int num_workers =
-    std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.get_shard_count()));
+    std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.GetShardCount()));
   std::vector<std::thread> threads;
   threads.reserve(num_workers);

@@ -576,7 +576,7 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() {
 void ShardIndexGenerator::DatabaseWriter() {
   int shard_no = task_++;
-  while (shard_no < shard_header_.get_shard_count()) {
+  while (shard_no < shard_header_.GetShardCount()) {
     auto db = CreateDatabase(shard_no);
     if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) {
       write_success_ = false;

@@ -592,10 +592,10 @@ void ShardIndexGenerator::DatabaseWriter() {
     std::vector<int> raw_page_ids;
     for (uint64_t i = 0; i < total_pages; ++i) {
       std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first;
-      if (cur_page->get_page_type() == "RAW_DATA") {
+      if (cur_page->GetPageType() == "RAW_DATA") {
         raw_page_ids.push_back(i);
-      } else if (cur_page->get_page_type() == "BLOB_DATA") {
-        blob_id_to_page_id[cur_page->get_page_type_id()] = i;
+      } else if (cur_page->GetPageType() == "BLOB_DATA") {
+        blob_id_to_page_id[cur_page->GetPageTypeID()] = i;
       }
     }
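One detail worth keeping in mind from WriteToDatabase above: the index is written by a small thread pool whose size is half the hardware threads plus one, capped at the shard count. A standalone restatement of that computation (illustrative helper, not part of the diff):

#include <algorithm>
#include <thread>

// Mirrors the num_workers expression in ShardIndexGenerator::WriteToDatabase():
// half the hardware threads (plus one), but never more than the shard count.
unsigned int ComputeNumWorkers(unsigned int shard_count) {
  return std::min(std::thread::hardware_concurrency() / 2 + 1, shard_count);
}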
mindspore/ccsrc/mindrecord/io/shard_reader.cc

@@ -56,9 +56,9 @@ MSRStatus ShardReader::Init(const std::string &file_path) {
     return FAILED;
   }
   shard_header_ = std::make_shared<ShardHeader>(sh);
-  header_size_ = shard_header_->get_header_size();
-  page_size_ = shard_header_->get_page_size();
-  file_paths_ = shard_header_->get_shard_addresses();
+  header_size_ = shard_header_->GetHeaderSize();
+  page_size_ = shard_header_->GetPageSize();
+  file_paths_ = shard_header_->GetShardAddresses();

   for (const auto &file : file_paths_) {
     sqlite3 *db = nullptr;

@@ -105,7 +105,7 @@ MSRStatus ShardReader::Init(const std::string &file_path) {
 MSRStatus ShardReader::CheckColumnList(const std::vector<std::string> &selected_columns) {
   vector<int> inSchema(selected_columns.size(), 0);
-  for (auto &p : get_shard_header()->get_schemas()) {
+  for (auto &p : GetShardHeader()->GetSchemas()) {
     auto schema = p->GetSchema()["schema"];
     for (unsigned int i = 0; i < selected_columns.size(); ++i) {
       if (schema.find(selected_columns[i]) != schema.end()) {

@@ -183,15 +183,15 @@ void ShardReader::Close() {
   FileStreamsOperator();
 }

-std::shared_ptr<ShardHeader> ShardReader::get_shard_header() const { return shard_header_; }
+std::shared_ptr<ShardHeader> ShardReader::GetShardHeader() const { return shard_header_; }

-int ShardReader::get_shard_count() const { return shard_header_->get_shard_count(); }
+int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); }

-int ShardReader::get_num_rows() const { return num_rows_; }
+int ShardReader::GetNumRows() const { return num_rows_; }

 std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummary() {
   std::vector<std::tuple<int, int, int, uint64_t>> row_group_summary;
-  int shard_count = shard_header_->get_shard_count();
+  int shard_count = shard_header_->GetShardCount();
   if (shard_count <= 0) {
     return row_group_summary;
   }

@@ -205,13 +205,13 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar
     for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) {
       const auto &page_t = shard_header_->GetPage(shard_id, page_id);
       const auto &page = page_t.first;
-      if (page->get_page_type() != kPageTypeBlob) continue;
-      uint64_t start_row_id = page->get_start_row_id();
-      if (start_row_id > page->get_end_row_id()) {
+      if (page->GetPageType() != kPageTypeBlob) continue;
+      uint64_t start_row_id = page->GetStartRowID();
+      if (start_row_id > page->GetEndRowID()) {
         return std::vector<std::tuple<int, int, int, uint64_t>>();
       }
-      uint64_t number_of_rows = page->get_end_row_id() - start_row_id;
-      row_group_summary.emplace_back(shard_id, page->get_page_type_id(), start_row_id, number_of_rows);
+      uint64_t number_of_rows = page->GetEndRowID() - start_row_id;
+      row_group_summary.emplace_back(shard_id, page->GetPageTypeID(), start_row_id, number_of_rows);
     }
   }
 }

@@ -265,7 +265,7 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
       json construct_json;
       for (unsigned int j = 0; j < columns.size(); ++j) {
         // construct json "f1": value
-        auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"];
+        auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];
         // convert the string to base type by schema
         if (schema[columns[j]]["type"] == "int32") {

@@ -317,7 +317,7 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql,
 MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) {
   std::map<std::string, uint64_t> index_columns;
-  for (auto &field : get_shard_header()->get_fields()) {
+  for (auto &field : GetShardHeader()->GetFields()) {
     index_columns[field.second] = field.first;
   }
   if (index_columns.find(category_field) == index_columns.end()) {

@@ -400,11 +400,11 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const
   }
   const std::shared_ptr<Page> &page = ret.second;
   std::string file_name = file_paths_[shard_id];
-  uint64_t page_length = page->get_page_size();
-  uint64_t page_offset = page_size_ * page->get_page_id() + header_size_;
-  std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->get_page_id(), shard_id);
+  uint64_t page_length = page->GetPageSize();
+  uint64_t page_offset = page_size_ * page->GetPageID() + header_size_;
+  std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->GetPageID(), shard_id);

-  auto status_labels = GetLabels(page->get_page_id(), shard_id, columns);
+  auto status_labels = GetLabels(page->GetPageID(), shard_id, columns);
   if (status_labels.first != SUCCESS) {
     return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>());
   }

@@ -426,11 +426,11 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupCriteria(int group_id, int shard_id,
   }
   const std::shared_ptr<Page> &page = ret.second;
   std::string file_name = file_paths_[shard_id];
-  uint64_t page_length = page->get_page_size();
-  uint64_t page_offset = page_size_ * page->get_page_id() + header_size_;
-  std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->get_page_id(), shard_id, criteria);
+  uint64_t page_length = page->GetPageSize();
+  uint64_t page_offset = page_size_ * page->GetPageID() + header_size_;
+  std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->GetPageID(), shard_id, criteria);

-  auto status_labels = GetLabels(page->get_page_id(), shard_id, columns, criteria);
+  auto status_labels = GetLabels(page->GetPageID(), shard_id, columns, criteria);
   if (status_labels.first != SUCCESS) {
     return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>());
   }

@@ -458,7 +458,7 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int
   // whether use index search
   if (!criteria.first.empty()) {
-    auto schema = shard_header_->get_schemas()[0]->GetSchema();
+    auto schema = shard_header_->GetSchemas()[0]->GetSchema();

     // not number field should add '' in sql
     if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) {

@@ -497,13 +497,13 @@ void ShardReader::CheckNlp() {
   return;
 }

-bool ShardReader::get_nlp_flag() { return nlp_; }
+bool ShardReader::GetNlpFlag() { return nlp_; }

-std::pair<ShardType, std::vector<std::string>> ShardReader::get_blob_fields() {
+std::pair<ShardType, std::vector<std::string>> ShardReader::GetBlobFields() {
   std::vector<std::string> blob_fields;
-  for (auto &p : get_shard_header()->get_schemas()) {
+  for (auto &p : GetShardHeader()->GetSchemas()) {
     // assume one schema
-    const auto &fields = p->get_blob_fields();
+    const auto &fields = p->GetBlobFields();
     blob_fields.assign(fields.begin(), fields.end());
     break;
   }

@@ -516,7 +516,7 @@ void ShardReader::CheckIfColumnInIndex(const std::vector<std::string> &columns)
     all_in_index_ = false;
     return;
   }
-  for (auto &field : get_shard_header()->get_fields()) {
+  for (auto &field : GetShardHeader()->GetFields()) {
     column_schema_id_[field.second] = field.first;
   }
   for (auto &col : columns) {

@@ -671,7 +671,7 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int
       json construct_json;
       for (unsigned int j = 0; j < columns.size(); ++j) {
         // construct json "f1": value
-        auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"];
+        auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];
         // convert the string to base type by schema
         if (schema[columns[j]]["type"] == "int32") {

@@ -719,9 +719,9 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
     return -1;
   }
   auto header = std::make_shared<ShardHeader>(sh);
-  auto file_paths = header->get_shard_addresses();
+  auto file_paths = header->GetShardAddresses();
   auto shard_count = file_paths.size();
-  auto index_fields = header->get_fields();
+  auto index_fields = header->GetFields();

   std::map<std::string, int64_t> map_schema_id_fields;
   for (auto &field : index_fields) {

@@ -799,7 +799,7 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
   if (nlp_) {
     selected_columns_ = selected_columns;
   } else {
-    vector<std::string> blob_fields = get_blob_fields().second;
+    vector<std::string> blob_fields = GetBlobFields().second;
     for (unsigned int i = 0; i < selected_columns.size(); ++i) {
       if (!std::any_of(blob_fields.begin(), blob_fields.end(),
                        [&selected_columns, i](std::string item) { return selected_columns[i] == item; })) {

@@ -846,7 +846,7 @@ MSRStatus ShardReader::OpenPy(const std::string &file_path, const int &n_consume
   }
   // should remove blob field from selected_columns when call from python
   std::vector<std::string> columns(selected_columns);
-  auto blob_fields = get_blob_fields().second;
+  auto blob_fields = GetBlobFields().second;
   for (auto &blob_field : blob_fields) {
     auto it = std::find(selected_columns.begin(), selected_columns.end(), blob_field);
     if (it != selected_columns.end()) {

@@ -909,7 +909,7 @@ vector<std::string> ShardReader::GetAllColumns() {
   vector<std::string> columns;
   if (nlp_) {
     for (auto &c : selected_columns_) {
-      for (auto &p : get_shard_header()->get_schemas()) {
+      for (auto &p : GetShardHeader()->GetSchemas()) {
        auto schema = p->GetSchema()["schema"];  // make sure schema is not reference since error occurred in arm.
        for (auto it = schema.begin(); it != schema.end(); ++it) {
          if (it.key() == c) {

@@ -943,7 +943,7 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
   CheckIfColumnInIndex(columns);

   auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
-  auto categories = category_op->get_categories();
+  auto categories = category_op->GetCategories();
   int64_t num_elements = category_op->GetNumElements();
   if (num_elements <= 0) {
     MS_LOG(ERROR) << "Parameter num_element is not positive";

@@ -1104,7 +1104,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
   }

   // Pick up task from task list
-  auto task = tasks_.get_task_by_id(tasks_.permutation_[task_id]);
+  auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]);
   auto shard_id = std::get<0>(std::get<0>(task));
   auto group_id = std::get<1>(std::get<0>(task));

@@ -1117,7 +1117,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
   // Pack image list
   std::vector<uint8_t> images(addr[1] - addr[0]);
-  auto file_offset = header_size_ + page_size_ * (page->get_page_id()) + addr[0];
+  auto file_offset = header_size_ + page_size_ * (page->GetPageID()) + addr[0];
   auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg);
   if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {

@@ -1139,7 +1139,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
   if (selected_columns_.size() == 0) {
     images_with_exact_columns = images;
   } else {
-    auto blob_fields = get_blob_fields();
+    auto blob_fields = GetBlobFields();
     std::vector<uint32_t> ordered_selected_columns_index;
     uint32_t index = 0;

@@ -1272,7 +1272,7 @@ MSRStatus ShardReader::ConsumerByBlock(int consumer_id) {
     }

     // Pick up task from task list
-    auto task = tasks_.get_task_by_id(tasks_.permutation_[task_id]);
+    auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]);
     auto shard_id = std::get<0>(std::get<0>(task));
     auto group_id = std::get<1>(std::get<0>(task));
mindspore/ccsrc/mindrecord/io/shard_segment.cc

@@ -28,7 +28,7 @@ using mindspore::MsLogLevel::INFO;
 namespace mindspore {
 namespace mindrecord {

-ShardSegment::ShardSegment() { set_all_in_index(false); }
+ShardSegment::ShardSegment() { SetAllInIndex(false); }

 std::pair<MSRStatus, vector<std::string>> ShardSegment::GetCategoryFields() {
   // Skip if already populated

@@ -211,7 +211,7 @@ std::pair<MSRStatus, std::vector<uint8_t>> ShardSegment::PackImages(int group_id
   // Pack image list
   std::vector<uint8_t> images(offset[1] - offset[0]);
-  auto file_offset = header_size_ + page_size_ * (blob_page->get_page_id()) + offset[0];
+  auto file_offset = header_size_ + page_size_ * (blob_page->GetPageID()) + offset[0];
   auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg);
   if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
     MS_LOG(ERROR) << "File seekg failed";

@@ -363,21 +363,21 @@ std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::obje
   return {SUCCESS, std::move(json_data)};
 }

-std::pair<ShardType, std::vector<std::string>> ShardSegment::get_blob_fields() {
+std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() {
   std::vector<std::string> blob_fields;
-  for (auto &p : get_shard_header()->get_schemas()) {
+  for (auto &p : GetShardHeader()->GetSchemas()) {
     // assume one schema
-    const auto &fields = p->get_blob_fields();
+    const auto &fields = p->GetBlobFields();
     blob_fields.assign(fields.begin(), fields.end());
     break;
   }
-  return std::make_pair(get_nlp_flag() ? kNLP : kCV, blob_fields);
+  return std::make_pair(GetNlpFlag() ? kNLP : kCV, blob_fields);
 }

 std::tuple<std::vector<uint8_t>, json> ShardSegment::GetImageLabel(std::vector<uint8_t> images, json label) {
-  if (get_nlp_flag()) {
+  if (GetNlpFlag()) {
     vector<std::string> columns;
-    for (auto &p : get_shard_header()->get_schemas()) {
+    for (auto &p : GetShardHeader()->GetSchemas()) {
       auto schema = p->GetSchema()["schema"];  // make sure schema is not reference since error occurred in arm.
       auto schema_items = schema.items();
       using it_type = decltype(schema_items.begin());
mindspore/ccsrc/mindrecord/io/shard_writer.cc

@@ -179,12 +179,12 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
     return FAILED;
   }
   shard_header_ = std::make_shared<ShardHeader>(sh);
-  auto paths = shard_header_->get_shard_addresses();
-  MSRStatus ret = set_header_size(shard_header_->get_header_size());
+  auto paths = shard_header_->GetShardAddresses();
+  MSRStatus ret = SetHeaderSize(shard_header_->GetHeaderSize());
   if (ret == FAILED) {
     return FAILED;
   }
-  ret = set_page_size(shard_header_->get_page_size());
+  ret = SetPageSize(shard_header_->GetPageSize());
   if (ret == FAILED) {
     return FAILED;
   }

@@ -229,10 +229,10 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data)
   }

   // set fields in mindrecord when empty
-  std::vector<std::pair<uint64_t, std::string>> fields = header_data->get_fields();
+  std::vector<std::pair<uint64_t, std::string>> fields = header_data->GetFields();
   if (fields.empty()) {
     MS_LOG(DEBUG) << "Missing index fields by user, auto generate index fields.";
-    std::vector<std::shared_ptr<Schema>> schemas = header_data->get_schemas();
+    std::vector<std::shared_ptr<Schema>> schemas = header_data->GetSchemas();
     for (const auto &schema : schemas) {
       json jsonSchema = schema->GetSchema()["schema"];
       for (const auto &el : jsonSchema.items()) {

@@ -241,7 +241,7 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data)
             (el.value()["type"] == "int64" && el.value().find("shape") == el.value().end()) ||
             (el.value()["type"] == "float32" && el.value().find("shape") == el.value().end()) ||
             (el.value()["type"] == "float64" && el.value().find("shape") == el.value().end())) {
-          fields.emplace_back(std::make_pair(schema->get_schema_id(), el.key()));
+          fields.emplace_back(std::make_pair(schema->GetSchemaID(), el.key()));
         }
       }
     }

@@ -256,12 +256,12 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data)
   }

   shard_header_ = header_data;
-  shard_header_->set_header_size(header_size_);
-  shard_header_->set_page_size(page_size_);
+  shard_header_->SetHeaderSize(header_size_);
+  shard_header_->SetPageSize(page_size_);
   return SUCCESS;
 }

-MSRStatus ShardWriter::set_header_size(const uint64_t &header_size) {
+MSRStatus ShardWriter::SetHeaderSize(const uint64_t &header_size) {
   // header_size [16KB, 128MB]
   if (header_size < kMinHeaderSize || header_size > kMaxHeaderSize) {
     MS_LOG(ERROR) << "Header size should between 16KB and 128MB.";

@@ -276,7 +276,7 @@ MSRStatus ShardWriter::set_header_size(const uint64_t &header_size) {
   return SUCCESS;
 }

-MSRStatus ShardWriter::set_page_size(const uint64_t &page_size) {
+MSRStatus ShardWriter::SetPageSize(const uint64_t &page_size) {
   // PageSize [32KB, 256MB]
   if (page_size < kMinPageSize || page_size > kMaxPageSize) {
     MS_LOG(ERROR) << "Page size should between 16KB and 256MB.";

@@ -398,7 +398,7 @@ MSRStatus ShardWriter::CheckData(const std::map<uint64_t, std::vector<json>> &ra
       return FAILED;
     }
     json schema = result.first->GetSchema()["schema"];
-    for (const auto &field : result.first->get_blob_fields()) {
+    for (const auto &field : result.first->GetBlobFields()) {
       (void)schema.erase(field);
     }
     std::vector<json> sub_raw_data = rawdata_iter->second;

@@ -456,7 +456,7 @@ std::tuple<MSRStatus, int, int> ShardWriter::ValidateRawData(std::map<uint64_t,
   MS_LOG(DEBUG) << "Schema count is " << schema_count_;

   // Determine if the number of schemas is the same
-  if (shard_header_->get_schemas().size() != schema_count_) {
+  if (shard_header_->GetSchemas().size() != schema_count_) {
     MS_LOG(ERROR) << "Data size is not equal with the schema size";
     return failed;
   }

@@ -475,9 +475,9 @@ std::tuple<MSRStatus, int, int> ShardWriter::ValidateRawData(std::map<uint64_t,
     }
     (void)schema_ids.insert(rawdata_iter->first);
   }
-  const std::vector<std::shared_ptr<Schema>> &schemas = shard_header_->get_schemas();
+  const std::vector<std::shared_ptr<Schema>> &schemas = shard_header_->GetSchemas();
   if (std::any_of(schemas.begin(), schemas.end(), [schema_ids](const std::shared_ptr<Schema> &schema) {
-        return schema_ids.find(schema->get_schema_id()) == schema_ids.end();
+        return schema_ids.find(schema->GetSchemaID()) == schema_ids.end();
       })) {
     // There is not enough data which is not matching the number of schema
     MS_LOG(ERROR) << "Input rawdata schema id do not match real schema id.";

@@ -810,10 +810,10 @@ MSRStatus ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector
                                    std::vector<std::pair<int, int>> &rows_in_group, const std::shared_ptr<Page> &last_raw_page,
                                    const std::shared_ptr<Page> &last_blob_page) {
-  auto n_byte_blob = last_blob_page ? last_blob_page->get_page_size() : 0;
+  auto n_byte_blob = last_blob_page ? last_blob_page->GetPageSize() : 0;

-  auto last_raw_page_size = last_raw_page ? last_raw_page->get_page_size() : 0;
-  auto last_raw_offset = last_raw_page ? last_raw_page->get_last_row_group_id().second : 0;
+  auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0;
+  auto last_raw_offset = last_raw_page ? last_raw_page->GetLastRowGroupID().second : 0;
   auto n_byte_raw = last_raw_page_size - last_raw_offset;

   int page_start_row = start_row;

@@ -849,8 +849,8 @@ MSRStatus ShardWriter::AppendBlobPage(const int &shard_id, const std::vector<std
   if (blob_row.first == blob_row.second) return SUCCESS;

   // Write disk
-  auto page_id = last_blob_page->get_page_id();
-  auto bytes_page = last_blob_page->get_page_size();
+  auto page_id = last_blob_page->GetPageID();
+  auto bytes_page = last_blob_page->GetPageSize();
   auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * page_id + header_size_ + bytes_page, std::ios::beg);
   if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
     MS_LOG(ERROR) << "File seekp failed";

@@ -862,9 +862,9 @@ MSRStatus ShardWriter::AppendBlobPage(const int &shard_id, const std::vector<std
   // Update last blob page
   bytes_page += std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0);
-  last_blob_page->set_page_size(bytes_page);
-  uint64_t end_row = last_blob_page->get_end_row_id() + blob_row.second - blob_row.first;
-  last_blob_page->set_end_row_id(end_row);
+  last_blob_page->SetPageSize(bytes_page);
+  uint64_t end_row = last_blob_page->GetEndRowID() + blob_row.second - blob_row.first;
+  last_blob_page->SetEndRowID(end_row);
   (void)shard_header_->SetPage(last_blob_page);
   return SUCCESS;
 }

@@ -873,8 +873,8 @@ MSRStatus ShardWriter::NewBlobPage(const int &shard_id, const std::vector<std::v
                                    const std::vector<std::pair<int, int>> &rows_in_group,
                                    const std::shared_ptr<Page> &last_blob_page) {
   auto page_id = shard_header_->GetLastPageId(shard_id);
-  auto page_type_id = last_blob_page ? last_blob_page->get_page_type_id() : -1;
-  auto current_row = last_blob_page ? last_blob_page->get_end_row_id() : 0;
+  auto page_type_id = last_blob_page ? last_blob_page->GetPageTypeID() : -1;
+  auto current_row = last_blob_page ? last_blob_page->GetEndRowID() : 0;
   // index(0) indicate appendBlobPage
   for (uint32_t i = 1; i < rows_in_group.size(); ++i) {
     auto blob_row = rows_in_group[i];

@@ -905,15 +905,15 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std::
                                     std::shared_ptr<Page> &last_raw_page) {
   auto blob_row = rows_in_group[0];
   if (blob_row.first == blob_row.second) return SUCCESS;
-  auto last_raw_page_size = last_raw_page ? last_raw_page->get_page_size() : 0;
+  auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0;
   if (std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0) +
         last_raw_page_size <= page_size_) {
     return SUCCESS;
   }

   auto page_id = shard_header_->GetLastPageId(shard_id);
-  auto last_row_group_id_offset = last_raw_page->get_last_row_group_id().second;
-  auto last_raw_page_id = last_raw_page->get_page_id();
+  auto last_row_group_id_offset = last_raw_page->GetLastRowGroupID().second;
+  auto last_raw_page_id = last_raw_page->GetPageID();
   auto shift_size = last_raw_page_size - last_row_group_id_offset;

   std::vector<uint8_t> buf(shift_size);

@@ -956,10 +956,10 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std::
   (void)shard_header_->SetPage(last_raw_page);

   // Refresh page info in header
-  int row_group_id = last_raw_page->get_last_row_group_id().first + 1;
+  int row_group_id = last_raw_page->GetLastRowGroupID().first + 1;
   std::vector<std::pair<int, uint64_t>> row_group_ids;
   row_group_ids.emplace_back(row_group_id, 0);
-  int page_type_id = last_raw_page->get_page_id();
+  int page_type_id = last_raw_page->GetPageID();
   auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, shift_size);
   (void)shard_header_->AddPage(std::make_shared<Page>(page));

@@ -971,7 +971,7 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std::
 MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
                                     std::shared_ptr<Page> &last_raw_page,
                                     const std::vector<std::vector<uint8_t>> &bin_raw_data) {
-  int last_row_group_id = last_raw_page ? last_raw_page->get_last_row_group_id().first : -1;
+  int last_row_group_id = last_raw_page ? last_raw_page->GetLastRowGroupID().first : -1;
   for (uint32_t i = 0; i < rows_in_group.size(); ++i) {
     const auto &blob_row = rows_in_group[i];
     if (blob_row.first == blob_row.second) continue;

@@ -979,7 +979,7 @@ MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std::
       std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0);
     if (!last_raw_page) {
       EmptyRawPage(shard_id, last_raw_page);
-    } else if (last_raw_page->get_page_size() + raw_size > page_size_) {
+    } else if (last_raw_page->GetPageSize() + raw_size > page_size_) {
      (void)shard_header_->SetPage(last_raw_page);
      EmptyRawPage(shard_id, last_raw_page);
    }

@@ -994,7 +994,7 @@ MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std::
 void ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page) {
   auto row_group_ids = std::vector<std::pair<int, uint64_t>>();
   auto page_id = shard_header_->GetLastPageId(shard_id);
-  auto page_type_id = last_raw_page ? last_raw_page->get_page_id() : -1;
+  auto page_type_id = last_raw_page ? last_raw_page->GetPageID() : -1;
   auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, 0);
   (void)shard_header_->AddPage(std::make_shared<Page>(page));
   SetLastRawPage(shard_id, last_raw_page);

@@ -1003,9 +1003,9 @@ void ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_
 MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
                                      const int &chunk_id, int &last_row_group_id, std::shared_ptr<Page> last_raw_page,
                                      const std::vector<std::vector<uint8_t>> &bin_raw_data) {
-  std::vector<std::pair<int, uint64_t>> row_group_ids = last_raw_page->get_row_group_ids();
-  auto last_raw_page_id = last_raw_page->get_page_id();
-  auto n_bytes = last_raw_page->get_page_size();
+  std::vector<std::pair<int, uint64_t>> row_group_ids = last_raw_page->GetRowGroupIds();
+  auto last_raw_page_id = last_raw_page->GetPageID();
+  auto n_bytes = last_raw_page->GetPageSize();

   // previous raw data page
   auto &io_seekp =

@@ -1022,8 +1022,8 @@ MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vector<std:
   (void)FlushRawChunk(file_streams_[shard_id], rows_in_group, chunk_id, bin_raw_data);

   // Update previous raw data page
-  last_raw_page->set_page_size(n_bytes);
-  last_raw_page->set_row_group_ids(row_group_ids);
+  last_raw_page->SetPageSize(n_bytes
);
last_raw_page
->
SetRowGroupI
ds
(
row_group_ids
);
(
void
)
shard_header_
->
SetPage
(
last_raw_page
);
return
SUCCESS
;
...
...
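A minimal standalone sketch of the page bookkeeping that AppendBlobPage performs with the renamed accessors. DemoBlobPage below is a hypothetical stand-in, not the real mindrecord::Page; only the accessor names and the accumulate/advance arithmetic mirror the hunks above.

// Hypothetical stand-in for mindrecord::Page, kept only to show the renamed accessors in context.
#include <cstdint>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

struct DemoBlobPage {
  uint64_t page_size = 0;
  uint64_t end_row_id = 0;
  uint64_t GetPageSize() const { return page_size; }
  void SetPageSize(uint64_t n) { page_size = n; }
  uint64_t GetEndRowID() const { return end_row_id; }
  void SetEndRowID(uint64_t n) { end_row_id = n; }
};

int main() {
  std::vector<uint64_t> blob_data_size = {100, 200, 300, 400};  // bytes per blob row
  std::pair<int, int> blob_row = {1, 3};                        // rows [1, 3) are appended to this page
  DemoBlobPage last_blob_page;
  last_blob_page.SetPageSize(1000);
  last_blob_page.SetEndRowID(10);

  // Same arithmetic as the AppendBlobPage hunk after the rename: grow the page by the
  // accumulated row sizes and advance the end row id by the number of rows appended.
  uint64_t bytes_page = last_blob_page.GetPageSize() +
                        std::accumulate(blob_data_size.begin() + blob_row.first,
                                        blob_data_size.begin() + blob_row.second, uint64_t(0));
  last_blob_page.SetPageSize(bytes_page);
  last_blob_page.SetEndRowID(last_blob_page.GetEndRowID() + blob_row.second - blob_row.first);

  std::cout << last_blob_page.GetPageSize() << " " << last_blob_page.GetEndRowID() << "\n";  // 1500 12
  return 0;
}

Running it prints "1500 12": 500 bytes from rows [1, 3) are added to the existing 1000-byte page, and the end row advances by two.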
mindspore/ccsrc/mindrecord/meta/shard_category.cc
...
...
@@ -34,7 +34,7 @@ ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elem
       num_categories_(num_categories),
       replacement_(replacement) {}
-MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; }
+MSRStatus ShardCategory::Execute(ShardTask &tasks) { return SUCCESS; }
 int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
   if (dataset_size == 0) return dataset_size;
...
...
mindspore/ccsrc/mindrecord/meta/shard_header.cc
...
...
@@ -343,7 +343,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() {
 std::string ShardHeader::SerializeIndexFields() {
   json j;
-  auto fields = index_->get_fields();
+  auto fields = index_->GetFields();
   for (const auto &field : fields) {
     j.push_back({{"schema_id", field.first}, {"index_field", field.second}});
   }
...
...
@@ -365,7 +365,7 @@ std::vector<std::string> ShardHeader::SerializePage() {
 std::string ShardHeader::SerializeStatistics() {
   json j;
   for (const auto &stats : statistics_) {
-    j.emplace_back(stats->get_statistics());
+    j.emplace_back(stats->GetStatistics());
   }
   return j.dump();
 }
...
...
@@ -398,8 +398,8 @@ MSRStatus ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) {
   if (new_page == nullptr) {
     return FAILED;
   }
-  int shard_id = new_page->get_shard_id();
-  int page_id = new_page->get_page_id();
+  int shard_id = new_page->GetShardID();
+  int page_id = new_page->GetPageID();
   if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) {
     pages_[shard_id][page_id] = new_page;
     return SUCCESS;
...
...
@@ -412,8 +412,8 @@ MSRStatus ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) {
   if (new_page == nullptr) {
     return FAILED;
   }
-  int shard_id = new_page->get_shard_id();
-  int page_id = new_page->get_page_id();
+  int shard_id = new_page->GetShardID();
+  int page_id = new_page->GetPageID();
   if (shard_id < static_cast<int>(pages_.size()) && page_id == static_cast<int>(pages_[shard_id].size())) {
     pages_[shard_id].push_back(new_page);
     return SUCCESS;
...
...
@@ -435,8 +435,8 @@ int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &pag
   }
   int last_page_id = -1;
   for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) {
-    if (pages_[shard_id][i - 1]->get_page_type() == page_type) {
-      last_page_id = pages_[shard_id][i - 1]->get_page_id();
+    if (pages_[shard_id][i - 1]->GetPageType() == page_type) {
+      last_page_id = pages_[shard_id][i - 1]->GetPageID();
       return last_page_id;
     }
   }
...
...
@@ -451,7 +451,7 @@ const std::pair<MSRStatus, std::shared_ptr<Page>> ShardHeader::GetPageByGroupId(
   }
   for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) {
     auto page = pages_[shard_id][i - 1];
-    if (page->get_page_type() == kPageTypeBlob && page->get_page_type_id() == group_id) {
+    if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) {
       return {SUCCESS, page};
     }
   }
...
...
@@ -470,10 +470,10 @@ int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) {
     return -1;
   }
-  int64_t schema_id = schema->get_schema_id();
+  int64_t schema_id = schema->GetSchemaID();
   if (schema_id == -1) {
     schema_id = schema_.size();
-    schema->set_schema_id(schema_id);
+    schema->SetSchemaID(schema_id);
   }
   schema_.push_back(schema);
   return schema_id;
...
...
@@ -481,10 +481,10 @@ int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) {
 void ShardHeader::AddStatistic(std::shared_ptr<Statistics> statistic) {
   if (statistic) {
-    int64_t statistics_id = statistic->get_statistics_id();
+    int64_t statistics_id = statistic->GetStatisticsID();
     if (statistics_id == -1) {
       statistics_id = statistics_.size();
-      statistic->set_statistics_id(statistics_id);
+      statistic->SetStatisticsID(statistics_id);
     }
     statistics_.push_back(statistic);
   }
...
...
@@ -527,13 +527,13 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
     return FAILED;
   }
-  if (get_schemas().empty()) {
+  if (GetSchemas().empty()) {
     MS_LOG(ERROR) << "No schema is set";
     return FAILED;
   }
   for (const auto &schemaPtr : schema_) {
-    auto result = GetSchemaByID(schemaPtr->get_schema_id());
+    auto result = GetSchemaByID(schemaPtr->GetSchemaID());
     if (result.second != SUCCESS) {
       MS_LOG(ERROR) << "Could not get schema by id.";
       return FAILED;
...
...
@@ -548,7 +548,7 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
     // checkout and add fields for each schema
     std::set<std::string> field_set;
-    for (const auto &item : index->get_fields()) {
+    for (const auto &item : index->GetFields()) {
       field_set.insert(item.second);
     }
     for (const auto &field : fields) {
...
...
@@ -564,7 +564,7 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
       field_set.insert(field);
       // add field into index
-      index.get()->AddIndexField(schemaPtr->get_schema_id(), field);
+      index.get()->AddIndexField(schemaPtr->GetSchemaID(), field);
     }
   }
...
...
@@ -575,12 +575,12 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
 MSRStatus ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) {
   // get all schema id
   for (const auto &schema : schema_) {
-    auto bucket_it = bucket_count.find(schema->get_schema_id());
+    auto bucket_it = bucket_count.find(schema->GetSchemaID());
    if (bucket_it != bucket_count.end()) {
       MS_LOG(ERROR) << "Schema duplication";
       return FAILED;
     } else {
-      bucket_count.insert(schema->get_schema_id());
+      bucket_count.insert(schema->GetSchemaID());
     }
   }
   return SUCCESS;
...
...
@@ -603,7 +603,7 @@ MSRStatus ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::strin
   // check and add fields for each schema
   std::set<std::pair<uint64_t, std::string>> field_set;
-  for (const auto &item : index->get_fields()) {
+  for (const auto &item : index->GetFields()) {
     field_set.insert(item);
   }
   for (const auto &field : fields) {
...
...
@@ -646,20 +646,20 @@ MSRStatus ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::strin
   return SUCCESS;
 }
-std::string ShardHeader::get_shard_address_by_id(int64_t shard_id) {
+std::string ShardHeader::GetShardAddressByID(int64_t shard_id) {
   if (shard_id >= shard_addresses_.size()) {
     return "";
   }
   return shard_addresses_.at(shard_id);
 }
-std::vector<std::shared_ptr<Schema>> ShardHeader::get_schemas() { return schema_; }
+std::vector<std::shared_ptr<Schema>> ShardHeader::GetSchemas() { return schema_; }
-std::vector<std::shared_ptr<Statistics>> ShardHeader::get_statistics() { return statistics_; }
+std::vector<std::shared_ptr<Statistics>> ShardHeader::GetStatistics() { return statistics_; }
-std::vector<std::pair<uint64_t, std::string>> ShardHeader::get_fields() { return index_->get_fields(); }
+std::vector<std::pair<uint64_t, std::string>> ShardHeader::GetFields() { return index_->GetFields(); }
-std::shared_ptr<Index> ShardHeader::get_index() { return index_; }
+std::shared_ptr<Index> ShardHeader::GetIndex() { return index_; }
 std::pair<std::shared_ptr<Schema>, MSRStatus> ShardHeader::GetSchemaByID(int64_t schema_id) {
   int64_t schemaSize = schema_.size();
...
...
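The id-assignment rule in AddSchema/AddStatistic is easy to miss in the rename noise: an id of -1 means "not yet assigned", and the header hands out the next container index via SetSchemaID. A standalone sketch of just that flow, using a hypothetical DemoSchema class rather than the real mindrecord::Schema and ShardHeader:

// Hypothetical stand-in for mindrecord::Schema / ShardHeader::AddSchema, illustrating only
// the renamed GetSchemaID/SetSchemaID assignment rule (-1 means "unassigned").
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

class DemoSchema {
 public:
  int64_t GetSchemaID() const { return schema_id_; }
  void SetSchemaID(int64_t id) { schema_id_ = id; }

 private:
  int64_t schema_id_ = -1;
};

int64_t AddSchema(std::vector<std::shared_ptr<DemoSchema>> &schemas, std::shared_ptr<DemoSchema> schema) {
  int64_t schema_id = schema->GetSchemaID();
  if (schema_id == -1) {  // unassigned: take the next slot index
    schema_id = schemas.size();
    schema->SetSchemaID(schema_id);
  }
  schemas.push_back(schema);
  return schema_id;
}

int main() {
  std::vector<std::shared_ptr<DemoSchema>> schemas;
  auto first = std::make_shared<DemoSchema>();
  auto second = std::make_shared<DemoSchema>();
  std::cout << AddSchema(schemas, first) << " " << AddSchema(schemas, second) << "\n";  // 0 1
  return 0;
}

The two AddSchema calls print "0 1", matching the push-back order.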
mindspore/ccsrc/mindrecord/meta/shard_index.cc
...
...
@@ -28,6 +28,6 @@ void Index::AddIndexField(const int64_t &schemaId, const std::string &field) {
 }
 // Get attribute list
-std::vector<std::pair<uint64_t, std::string>> Index::get_fields() { return fields_; }
+std::vector<std::pair<uint64_t, std::string>> Index::GetFields() { return fields_; }
 }  // namespace mindrecord
 }  // namespace mindspore
mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc
...
...
@@ -34,7 +34,7 @@ ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elem
   shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);  // do shuffle and replacement
 }
-MSRStatus ShardPkSample::suf_execute(ShardTask &tasks) {
+MSRStatus ShardPkSample::SufExecute(ShardTask &tasks) {
   if (shuffle_ == true) {
     if (SUCCESS != (*shuffle_op_)(tasks)) {
       return FAILED;
...
...
mindspore/ccsrc/mindrecord/meta/shard_sample.cc
...
...
@@ -74,14 +74,14 @@ int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
     return -1;
   }
-const std::pair<int, int> ShardSample::get_partitions() const {
+const std::pair<int, int> ShardSample::GetPartitions() const {
   if (numerator_ == 1 && denominator_ > 1) {
     return std::pair<int, int>(denominator_, partition_id_);
   }
   return std::pair<int, int>(-1, -1);
 }
-MSRStatus ShardSample::execute(ShardTask &tasks) {
+MSRStatus ShardSample::Execute(ShardTask &tasks) {
   int no_of_categories = static_cast<int>(tasks.categories);
   int total_no = static_cast<int>(tasks.Size());
...
...
@@ -114,11 +114,11 @@ MSRStatus ShardSample::execute(ShardTask &tasks) {
   if (sampler_type_ == kSubsetRandomSampler) {
     for (int i = 0; i < indices_.size(); ++i) {
       int index = ((indices_[i] % total_no) + total_no) % total_no;
-      new_tasks.InsertTask(tasks.get_task_by_id(index));  // different mod result between c and python
+      new_tasks.InsertTask(tasks.GetTaskByID(index));  // different mod result between c and python
     }
   } else {
     for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
-      new_tasks.InsertTask(tasks.get_task_by_id(i % total_no));  // rounding up. if overflow, go back to start
+      new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));  // rounding up. if overflow, go back to start
     }
   }
   std::swap(tasks, new_tasks);
...
...
@@ -129,14 +129,14 @@ MSRStatus ShardSample::execute(ShardTask &tasks) {
   }
     total_no = static_cast<int>(tasks.permutation_.size());
     for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
-      new_tasks.InsertTask(tasks.get_task_by_id(tasks.permutation_[i % total_no]));
+      new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
     }
     std::swap(tasks, new_tasks);
   }
   return SUCCESS;
 }
-MSRStatus ShardSample::suf_execute(ShardTask &tasks) {
+MSRStatus ShardSample::SufExecute(ShardTask &tasks) {
   if (sampler_type_ == kSubsetRandomSampler) {
     if (SUCCESS != (*shuffle_op_)(tasks)) {
       return FAILED;
...
...
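The GetTaskByID(i % total_no) pattern in ShardSample::Execute slices the range [partition_id * taking, (partition_id + 1) * taking) and wraps past the end of the task list. A standalone sketch of just that index arithmetic; the constants below are made up for illustration and the picked indices stand in for real tasks:

// Standalone sketch of the partition-slice arithmetic after the rename: each partition takes
// `taking` tasks starting at partition_id * taking, wrapping modulo the total task count.
#include <iostream>
#include <vector>

int main() {
  const int total_no = 10;     // tasks.Size() in the real code (illustrative value)
  const int taking = 4;        // tasks each partition should take
  const int partition_id = 2;  // third partition

  std::vector<int> picked;
  for (int i = partition_id * taking; i < (partition_id + 1) * taking; i++) {
    picked.push_back(i % total_no);  // rounding up: overflow wraps back to the start
  }
  for (int idx : picked) std::cout << idx << " ";  // 8 9 0 1
  std::cout << "\n";
  return 0;
}

With 10 tasks, taking = 4 and partition_id = 2, the printed indices are 8 9 0 1, i.e. the last partition wraps back to the start of the list.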
mindspore/ccsrc/mindrecord/meta/shard_schema.cc
...
...
@@ -44,7 +44,7 @@ std::shared_ptr<Schema> Schema::Build(std::string desc, pybind11::handle schema)
   return Build(std::move(desc), schema_json);
 }
-std::string Schema::get_desc() const { return desc_; }
+std::string Schema::GetDesc() const { return desc_; }
 json Schema::GetSchema() const {
   json str_schema;
...
...
@@ -60,11 +60,11 @@ pybind11::object Schema::GetSchemaForPython() const {
   return schema_py;
 }
-void Schema::set_schema_id(int64_t id) { schema_id_ = id; }
+void Schema::SetSchemaID(int64_t id) { schema_id_ = id; }
-int64_t Schema::get_schema_id() const { return schema_id_; }
+int64_t Schema::GetSchemaID() const { return schema_id_; }
-std::vector<std::string> Schema::get_blob_fields() const { return blob_fields_; }
+std::vector<std::string> Schema::GetBlobFields() const { return blob_fields_; }
 std::vector<std::string> Schema::PopulateBlobFields(json schema) {
   std::vector<std::string> blob_fields;
...
...
@@ -155,7 +155,7 @@ bool Schema::Validate(json schema) {
 }
 bool Schema::operator==(const mindrecord::Schema &b) const {
-  if (this->get_desc() != b.get_desc() || this->GetSchema() != b.GetSchema()) {
+  if (this->GetDesc() != b.GetDesc() || this->GetSchema() != b.GetSchema()) {
     return false;
   }
   return true;
...
...
mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc
...
...
@@ -23,7 +23,7 @@ namespace mindrecord {
 ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type)
     : shuffle_seed_(seed), shuffle_type_(shuffle_type) {}
-MSRStatus ShardShuffle::execute(ShardTask &tasks) {
+MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
   if (tasks.categories < 1) {
     return FAILED;
   }
...
...
mindspore/ccsrc/mindrecord/meta/shard_statistics.cc
...
...
@@ -48,9 +48,9 @@ std::shared_ptr<Statistics> Statistics::Build(std::string desc, pybind11::handle
   return std::make_shared<Statistics>(object_statistics);
 }
-std::string Statistics::get_desc() const { return desc_; }
+std::string Statistics::GetDesc() const { return desc_; }
-json Statistics::get_statistics() const {
+json Statistics::GetStatistics() const {
   json str_statistics;
   str_statistics["desc"] = desc_;
   str_statistics["statistics"] = statistics_;
...
...
@@ -58,13 +58,13 @@ json Statistics::get_statistics() const {
 }
 pybind11::object Statistics::GetStatisticsForPython() const {
-  json str_statistics = Statistics::get_statistics();
+  json str_statistics = Statistics::GetStatistics();
   return nlohmann::detail::FromJsonImpl(str_statistics);
 }
-void Statistics::set_statistics_id(int64_t id) { statistics_id_ = id; }
+void Statistics::SetStatisticsID(int64_t id) { statistics_id_ = id; }
-int64_t Statistics::get_statistics_id() const { return statistics_id_; }
+int64_t Statistics::GetStatisticsID() const { return statistics_id_; }
 bool Statistics::Validate(const json &statistics) {
   if (statistics.size() != kInt1) {
...
...
@@ -103,7 +103,7 @@ bool Statistics::LevelRecursive(json level) {
 }
 bool Statistics::operator==(const Statistics &b) const {
-  if (this->get_statistics() != b.get_statistics()) {
+  if (this->GetStatistics() != b.GetStatistics()) {
     return false;
   }
   return true;
...
...
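GetStatistics() (formerly get_statistics()) assembles a {"desc": ..., "statistics": ...} json object. A standalone sketch of that shape, assuming the nlohmann/json single header is available on the include path as it is elsewhere in this code; DemoStatistics is a hypothetical stand-in, not mindrecord::Statistics.

// Hypothetical stand-in showing only the json shape returned by the renamed GetStatistics().
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

struct DemoStatistics {
  std::string desc_;
  json statistics_;
  json GetStatistics() const {
    json str_statistics;
    str_statistics["desc"] = desc_;
    str_statistics["statistics"] = statistics_;
    return str_statistics;
  }
};

int main() {
  DemoStatistics stats{"image level statistics", json::parse(R"({"level": []})")};
  std::cout << stats.GetStatistics().dump() << "\n";  // {"desc":"image level statistics","statistics":{"level":[]}}
  return 0;
}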
mindspore/ccsrc/mindrecord/meta/shard_task.cc
...
...
@@ -59,12 +59,12 @@ uint32_t ShardTask::SizeOfRows() const {
   return nRows;
 }
-std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::get_task_by_id(size_t id) {
+std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetTaskByID(size_t id) {
   MS_ASSERT(id < task_list_.size());
   return task_list_[id];
 }
-std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::get_random_task() {
+std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetRandomTask() {
   std::random_device rd;
   std::mt19937 gen(rd());
   std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
...
...
@@ -82,7 +82,7 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
     }
     for (uint32_t task_no = 0; task_no < minTasks; task_no++) {
       for (uint32_t i = 0; i < total_categories; i++) {
-        res.InsertTask(std::move(category_tasks[i].get_task_by_id(static_cast<int>(task_no))));
+        res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast<int>(task_no))));
       }
     }
   } else {
...
...
@@ -95,7 +95,7 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
     }
     for (uint32_t i = 0; i < total_categories; i++) {
       for (uint32_t j = 0; j < maxTasks; j++) {
-        res.InsertTask(category_tasks[i].get_random_task());
+        res.InsertTask(category_tasks[i].GetRandomTask());
       }
     }
   }
...
...
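GetRandomTask() (formerly get_random_task()) draws a uniform index over the task list with std::mt19937, as the hunk above shows. A standalone sketch of that selection, with a plain string payload standing in for the real tuple task type:

// Standalone sketch of the uniform random pick performed by the renamed GetRandomTask().
#include <iostream>
#include <random>
#include <string>
#include <vector>

const std::string &GetRandomTask(const std::vector<std::string> &task_list) {
  std::random_device rd;
  std::mt19937 gen(rd());
  std::uniform_int_distribution<> dis(0, static_cast<int>(task_list.size()) - 1);
  return task_list[dis(gen)];  // any element, each with equal probability
}

int main() {
  std::vector<std::string> tasks = {"task0", "task1", "task2"};
  std::cout << GetRandomTask(tasks) << "\n";  // one of task0 / task1 / task2
  return 0;
}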
tests/ut/cpp/mindrecord/ut_shard.cc
...
...
@@ -52,7 +52,7 @@ TEST_F(TestShard, TestShardSchemaPart) {
   std::shared_ptr<Schema> schema = Schema::Build(desc, j);
   ASSERT_TRUE(schema != nullptr);
-  MS_LOG(INFO) << "schema description: " << schema->get_desc() << ", schema: " <<
+  MS_LOG(INFO) << "schema description: " << schema->GetDesc() << ", schema: " <<
     common::SafeCStr(schema->GetSchema().dump());
   for (int i = 1; i <= 4; i++) {
     string filename = std::string("./imagenet.shard0") + std::to_string(i);
...
...
@@ -71,8 +71,8 @@ TEST_F(TestShard, TestStatisticPart) {
   nlohmann::json statistic_json = json::parse(kStatistics[2]);
   std::shared_ptr<Statistics> statistics = Statistics::Build(desc, statistic_json);
   ASSERT_TRUE(statistics != nullptr);
-  MS_LOG(INFO) << "test get_desc(), result: " << statistics->get_desc();
-  MS_LOG(INFO) << "test get_statistics, result: " << statistics->get_statistics().dump();
+  MS_LOG(INFO) << "test get_desc(), result: " << statistics->GetDesc();
+  MS_LOG(INFO) << "test get_statistics, result: " << statistics->GetStatistics().dump();
   std::string desc2 = "axis";
   nlohmann::json statistic_json2 = R"({})";
...
...
@@ -111,13 +111,13 @@ TEST_F(TestShard, TestShardHeaderPart) {
   ASSERT_EQ(res, 0);
   header_data.AddStatistic(statistics1);
   std::vector<Schema> re_schemas;
-  for (auto &schema_ptr : header_data.get_schemas()) {
+  for (auto &schema_ptr : header_data.GetSchemas()) {
     re_schemas.push_back(*schema_ptr);
   }
   ASSERT_EQ(re_schemas, validate_schema);
   std::vector<Statistics> re_statistics;
-  for (auto &statistic : header_data.get_statistics()) {
+  for (auto &statistic : header_data.GetStatistics()) {
     re_statistics.push_back(*statistic);
   }
   ASSERT_EQ(re_statistics, validate_statistics);
...
...
@@ -129,7 +129,7 @@ TEST_F(TestShard, TestShardHeaderPart) {
   std::pair<uint64_t, std::string> pair1(0, "name");
   fields.push_back(pair1);
   ASSERT_TRUE(header_data.AddIndexFields(fields) == SUCCESS);
-  std::vector<std::pair<uint64_t, std::string>> resFields = header_data.get_fields();
+  std::vector<std::pair<uint64_t, std::string>> resFields = header_data.GetFields();
   ASSERT_EQ(resFields, fields);
 }
...
...
tests/ut/cpp/mindrecord/ut_shard_header_test.cc
...
...
@@ -70,7 +70,7 @@ TEST_F(TestShardHeader, AddIndexFields) {
   int schema_id1 = header_data.AddSchema(schema1);
   int schema_id2 = header_data.AddSchema(schema2);
   ASSERT_EQ(schema_id2, -1);
-  ASSERT_EQ(header_data.get_schemas().size(), 1);
+  ASSERT_EQ(header_data.GetSchemas().size(), 1);
   // check out fields
   std::vector<std::pair<uint64_t, std::string>> fields;
...
...
@@ -81,35 +81,35 @@ TEST_F(TestShardHeader, AddIndexFields) {
   fields.push_back(index_field2);
   MSRStatus res = header_data.AddIndexFields(fields);
   ASSERT_EQ(res, SUCCESS);
-  ASSERT_EQ(header_data.get_fields().size(), 2);
+  ASSERT_EQ(header_data.GetFields().size(), 2);
   fields.clear();
   std::pair<uint64_t, std::string> index_field3(schema_id1, "name");
   fields.push_back(index_field3);
   res = header_data.AddIndexFields(fields);
   ASSERT_EQ(res, FAILED);
-  ASSERT_EQ(header_data.get_fields().size(), 2);
+  ASSERT_EQ(header_data.GetFields().size(), 2);
   fields.clear();
   std::pair<uint64_t, std::string> index_field4(schema_id1, "names");
   fields.push_back(index_field4);
   res = header_data.AddIndexFields(fields);
   ASSERT_EQ(res, FAILED);
-  ASSERT_EQ(header_data.get_fields().size(), 2);
+  ASSERT_EQ(header_data.GetFields().size(), 2);
   fields.clear();
   std::pair<uint64_t, std::string> index_field5(schema_id1 + 1, "name");
   fields.push_back(index_field5);
   res = header_data.AddIndexFields(fields);
   ASSERT_EQ(res, FAILED);
-  ASSERT_EQ(header_data.get_fields().size(), 2);
+  ASSERT_EQ(header_data.GetFields().size(), 2);
   fields.clear();
   std::pair<uint64_t, std::string> index_field6(schema_id1, "label");
   fields.push_back(index_field6);
   res = header_data.AddIndexFields(fields);
   ASSERT_EQ(res, FAILED);
-  ASSERT_EQ(header_data.get_fields().size(), 2);
+  ASSERT_EQ(header_data.GetFields().size(), 2);
   std::string desc_new = "this is a test1";
   json schemaContent_new = R"({"name": {"type": "string"},
...
...
@@ -121,7 +121,7 @@ TEST_F(TestShardHeader, AddIndexFields) {
   mindrecord::ShardHeader header_data_new;
   header_data_new.AddSchema(schema_new);
-  ASSERT_EQ(header_data_new.get_schemas().size(), 1);
+  ASSERT_EQ(header_data_new.GetSchemas().size(), 1);
   // test add fields
   std::vector<std::string> single_fields;
...
...
@@ -131,25 +131,25 @@ TEST_F(TestShardHeader, AddIndexFields) {
   single_fields.push_back("box");
   res = header_data_new.AddIndexFields(single_fields);
   ASSERT_EQ(res, FAILED);
-  ASSERT_EQ(header_data_new.get_fields().size(), 1);
+  ASSERT_EQ(header_data_new.GetFields().size(), 1);
   single_fields.push_back("name");
   single_fields.push_back("box");
   res = header_data_new.AddIndexFields(single_fields);
   ASSERT_EQ(res, FAILED);
-  ASSERT_EQ(header_data_new.get_fields().size(), 1);
+  ASSERT_EQ(header_data_new.GetFields().size(), 1);
   single_fields.clear();
   single_fields.push_back("names");
   res = header_data_new.AddIndexFields(single_fields);
   ASSERT_EQ(res, FAILED);
-  ASSERT_EQ(header_data_new.get_fields().size(), 1);
+  ASSERT_EQ(header_data_new.GetFields().size(), 1);
   single_fields.clear();
   single_fields.push_back("box");
   res = header_data_new.AddIndexFields(single_fields);
   ASSERT_EQ(res, SUCCESS);
-  ASSERT_EQ(header_data_new.get_fields().size(), 2);
+  ASSERT_EQ(header_data_new.GetFields().size(), 2);
 }
 }  // namespace mindrecord
 }  // namespace mindspore
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
...
...
@@ -139,7 +139,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) {
   const int kPar = 2;
   std::vector<std::shared_ptr<ShardOperator>> ops;
   ops.push_back(std::make_shared<ShardSample>(kNum, kDen, kPar));
-  auto partitions = std::dynamic_pointer_cast<ShardSample>(ops[0])->get_partitions();
+  auto partitions = std::dynamic_pointer_cast<ShardSample>(ops[0])->GetPartitions();
   ASSERT_TRUE(partitions.first == 4);
   ASSERT_TRUE(partitions.second == 2);
...
...
tests/ut/cpp/mindrecord/ut_shard_page_test.cc
...
...
@@ -57,15 +57,15 @@ TEST_F(TestShardPage, TestBasic) {
   Page page = Page(kGoldenPageId, kGoldenShardId, kGoldenType, kGoldenTypeId, kGoldenStart, kGoldenEnd,
                    golden_row_group, kGoldenSize);
-  EXPECT_EQ(kGoldenPageId, page.get_page_id());
-  EXPECT_EQ(kGoldenShardId, page.get_shard_id());
-  EXPECT_EQ(kGoldenTypeId, page.get_page_type_id());
-  ASSERT_TRUE(kGoldenType == page.get_page_type());
-  EXPECT_EQ(kGoldenSize, page.get_page_size());
-  EXPECT_EQ(kGoldenStart, page.get_start_row_id());
-  EXPECT_EQ(kGoldenEnd, page.get_end_row_id());
-  ASSERT_TRUE(std::make_pair(4, kOffset) == page.get_last_row_group_id());
-  ASSERT_TRUE(golden_row_group == page.get_row_group_ids());
+  EXPECT_EQ(kGoldenPageId, page.GetPageID());
+  EXPECT_EQ(kGoldenShardId, page.GetShardID());
+  EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID());
+  ASSERT_TRUE(kGoldenType == page.GetPageType());
+  EXPECT_EQ(kGoldenSize, page.GetPageSize());
+  EXPECT_EQ(kGoldenStart, page.GetStartRowID());
+  EXPECT_EQ(kGoldenEnd, page.GetEndRowID());
+  ASSERT_TRUE(std::make_pair(4, kOffset) == page.GetLastRowGroupID());
+  ASSERT_TRUE(golden_row_group == page.GetRowGroupIds());
 }
 TEST_F(TestShardPage, TestSetter) {
...
...
@@ -86,43 +86,43 @@ TEST_F(TestShardPage, TestSetter) {
   Page page = Page(kGoldenPageId, kGoldenShardId, kGoldenType, kGoldenTypeId, kGoldenStart, kGoldenEnd,
                    golden_row_group, kGoldenSize);
-  EXPECT_EQ(kGoldenPageId, page.get_page_id());
-  EXPECT_EQ(kGoldenShardId, page.get_shard_id());
-  EXPECT_EQ(kGoldenTypeId, page.get_page_type_id());
-  ASSERT_TRUE(kGoldenType == page.get_page_type());
-  EXPECT_EQ(kGoldenSize, page.get_page_size());
-  EXPECT_EQ(kGoldenStart, page.get_start_row_id());
-  EXPECT_EQ(kGoldenEnd, page.get_end_row_id());
-  ASSERT_TRUE(std::make_pair(4, kOffset1) == page.get_last_row_group_id());
-  ASSERT_TRUE(golden_row_group == page.get_row_group_ids());
+  EXPECT_EQ(kGoldenPageId, page.GetPageID());
+  EXPECT_EQ(kGoldenShardId, page.GetShardID());
+  EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID());
+  ASSERT_TRUE(kGoldenType == page.GetPageType());
+  EXPECT_EQ(kGoldenSize, page.GetPageSize());
+  EXPECT_EQ(kGoldenStart, page.GetStartRowID());
+  EXPECT_EQ(kGoldenEnd, page.GetEndRowID());
+  ASSERT_TRUE(std::make_pair(4, kOffset1) == page.GetLastRowGroupID());
+  ASSERT_TRUE(golden_row_group == page.GetRowGroupIds());
   const int kNewEnd = 33;
   const int kNewSize = 300;
   std::vector<std::pair<int, uint64_t>> new_row_group = {{0, 100}, {100, 200}, {200, 3000}};
-  page.set_end_row_id(kNewEnd);
-  page.set_page_size(kNewSize);
-  page.set_row_group_ids(new_row_group);
-  EXPECT_EQ(kGoldenPageId, page.get_page_id());
-  EXPECT_EQ(kGoldenShardId, page.get_shard_id());
-  EXPECT_EQ(kGoldenTypeId, page.get_page_type_id());
-  ASSERT_TRUE(kGoldenType == page.get_page_type());
-  EXPECT_EQ(kNewSize, page.get_page_size());
-  EXPECT_EQ(kGoldenStart, page.get_start_row_id());
-  EXPECT_EQ(kNewEnd, page.get_end_row_id());
-  ASSERT_TRUE(std::make_pair(200, kOffset2) == page.get_last_row_group_id());
-  ASSERT_TRUE(new_row_group == page.get_row_group_ids());
+  page.SetEndRowID(kNewEnd);
+  page.SetPageSize(kNewSize);
+  page.SetRowGroupIds(new_row_group);
+  EXPECT_EQ(kGoldenPageId, page.GetPageID());
+  EXPECT_EQ(kGoldenShardId, page.GetShardID());
+  EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID());
+  ASSERT_TRUE(kGoldenType == page.GetPageType());
+  EXPECT_EQ(kNewSize, page.GetPageSize());
+  EXPECT_EQ(kGoldenStart, page.GetStartRowID());
+  EXPECT_EQ(kNewEnd, page.GetEndRowID());
+  ASSERT_TRUE(std::make_pair(200, kOffset2) == page.GetLastRowGroupID());
+  ASSERT_TRUE(new_row_group == page.GetRowGroupIds());
   page.DeleteLastGroupId();
-  EXPECT_EQ(kGoldenPageId, page.get_page_id());
-  EXPECT_EQ(kGoldenShardId, page.get_shard_id());
-  EXPECT_EQ(kGoldenTypeId, page.get_page_type_id());
-  ASSERT_TRUE(kGoldenType == page.get_page_type());
-  EXPECT_EQ(3000, page.get_page_size());
-  EXPECT_EQ(kGoldenStart, page.get_start_row_id());
-  EXPECT_EQ(kNewEnd, page.get_end_row_id());
-  ASSERT_TRUE(std::make_pair(100, kOffset3) == page.get_last_row_group_id());
+  EXPECT_EQ(kGoldenPageId, page.GetPageID());
+  EXPECT_EQ(kGoldenShardId, page.GetShardID());
+  EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID());
+  ASSERT_TRUE(kGoldenType == page.GetPageType());
+  EXPECT_EQ(3000, page.GetPageSize());
+  EXPECT_EQ(kGoldenStart, page.GetStartRowID());
+  EXPECT_EQ(kNewEnd, page.GetEndRowID());
+  ASSERT_TRUE(std::make_pair(100, kOffset3) == page.GetLastRowGroupID());
   new_row_group.pop_back();
-  ASSERT_TRUE(new_row_group == page.get_row_group_ids());
+  ASSERT_TRUE(new_row_group == page.GetRowGroupIds());
 }
 TEST_F(TestShardPage, TestJson) {
...
...
tests/ut/cpp/mindrecord/ut_shard_schema_test.cc
...
...
@@ -107,15 +107,15 @@ TEST_F(TestShardSchema, TestFunction) {
   std::shared_ptr<Schema> schema = Schema::Build(desc, schema_content);
   ASSERT_NE(schema, nullptr);
-  ASSERT_EQ(schema->get_desc(), desc);
+  ASSERT_EQ(schema->GetDesc(), desc);
   json schema_json = schema->GetSchema();
   ASSERT_EQ(schema_json["desc"], desc);
   ASSERT_EQ(schema_json["schema"], schema_content);
-  ASSERT_EQ(schema->get_schema_id(), -1);
-  schema->set_schema_id(2);
-  ASSERT_EQ(schema->get_schema_id(), 2);
+  ASSERT_EQ(schema->GetSchemaID(), -1);
+  schema->SetSchemaID(2);
+  ASSERT_EQ(schema->GetSchemaID(), 2);
 }
 TEST_F(TestStatistics, StatisticPart) {
...
...
@@ -137,8 +137,8 @@ TEST_F(TestStatistics, StatisticPart) {
   ASSERT_NE(statistics, nullptr);
-  MS_LOG(INFO) << "test get_desc(), result: " << statistics->get_desc();
-  MS_LOG(INFO) << "test get_statistics, result: " << statistics->get_statistics().dump();
+  MS_LOG(INFO) << "test GetDesc(), result: " << statistics->GetDesc();
+  MS_LOG(INFO) << "test GetStatistics, result: " << statistics->GetStatistics().dump();
   statistic_json["test"] = "test";
   statistics = Statistics::Build(desc, statistic_json);
...
...
tests/ut/cpp/mindrecord/ut_shard_writer_test.cc
...
...
@@ -194,8 +194,8 @@ TEST_F(TestShardWriter, TestShardWriterShiftRawPage) {
   fw.Open(file_names);
   uint64_t header_size = 1 << 14;
   uint64_t page_size = 1 << 15;
-  fw.set_header_size(header_size);
-  fw.set_page_size(page_size);
+  fw.SetHeaderSize(header_size);
+  fw.SetPageSize(page_size);
   // set shardHeader
   fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
...
...
@@ -331,8 +331,8 @@ TEST_F(TestShardWriter, TestShardWriterTrial) {
   fw.Open(file_names);
   uint64_t header_size = 1 << 14;
   uint64_t page_size = 1 << 17;
-  fw.set_header_size(header_size);
-  fw.set_page_size(page_size);
+  fw.SetHeaderSize(header_size);
+  fw.SetPageSize(page_size);
   // set shardHeader
   fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
...
...
@@ -466,8 +466,8 @@ TEST_F(TestShardWriter, TestShardWriterTrialNoFields) {
   fw.Open(file_names);
   uint64_t header_size = 1 << 14;
   uint64_t page_size = 1 << 17;
-  fw.set_header_size(header_size);
-  fw.set_page_size(page_size);
+  fw.SetHeaderSize(header_size);
+  fw.SetPageSize(page_size);
   // set shardHeader
   fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
...
...
@@ -567,8 +567,8 @@ TEST_F(TestShardWriter, DataCheck) {
   fw.Open(file_names);
   uint64_t header_size = 1 << 14;
   uint64_t page_size = 1 << 17;
-  fw.set_header_size(header_size);
-  fw.set_page_size(page_size);
+  fw.SetHeaderSize(header_size);
+  fw.SetPageSize(page_size);
   // set shardHeader
   fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
...
...
@@ -668,8 +668,8 @@ TEST_F(TestShardWriter, AllRawDataWrong) {
   fw.Open(file_names);
   uint64_t header_size = 1 << 14;
   uint64_t page_size = 1 << 17;
-  fw.set_header_size(header_size);
-  fw.set_page_size(page_size);
+  fw.SetHeaderSize(header_size);
+  fw.SetPageSize(page_size);
   // set shardHeader
   fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
...
...