Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a694fb33
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a694fb33
编写于
5月 11, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(serialization): implement the new serialization format
GitOrigin-RevId: 00f87f7ccdae7d313d8a97108c3ae712b1c27da3
上级
ca4a5da0
变更
10
展开全部
显示空白变更内容
内联
并排
Showing
10 changed file
with
1210 addition
and
52 deletion
+1210
-52
src/serialization/impl/batched_device_value_loader.cpp
src/serialization/impl/batched_device_value_loader.cpp
+1
-2
src/serialization/impl/serializer.cpp
src/serialization/impl/serializer.cpp
+17
-0
src/serialization/impl/serializer_oss.cpp
src/serialization/impl/serializer_oss.cpp
+2
-49
src/serialization/impl/serializer_oss_common.cpp
src/serialization/impl/serializer_oss_common.cpp
+39
-0
src/serialization/impl/serializer_oss_common.h
src/serialization/impl/serializer_oss_common.h
+32
-0
src/serialization/impl/serializer_oss_v2.cpp
src/serialization/impl/serializer_oss_v2.cpp
+847
-0
src/serialization/include/megbrain/serialization/batched_device_value_loader.h
...lude/megbrain/serialization/batched_device_value_loader.h
+0
-0
src/serialization/include/megbrain/serialization/dump_format.h
...erialization/include/megbrain/serialization/dump_format.h
+1
-0
src/serialization/include/megbrain/serialization/opr_load_dump.h
...ialization/include/megbrain/serialization/opr_load_dump.h
+14
-1
src/serialization/include/megbrain/serialization/oss_opr_load_dump.h
...zation/include/megbrain/serialization/oss_opr_load_dump.h
+257
-0
未找到文件。
src/serialization/impl/batched_device_value_loader.cpp
浏览文件 @
a694fb33
#include "batched_device_value_loader.h"
#include "megbrain/serialization/batched_device_value_loader.h"
#include "megbrain/utils/arith_helper.h"
namespace
mgb
{
...
...
src/serialization/impl/serializer.cpp
浏览文件 @
a694fb33
...
...
@@ -57,7 +57,11 @@ GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() {
}
std
::
unique_ptr
<
GraphLoader
>
make_fbs_loader
(
std
::
unique_ptr
<
InputFile
>
file
);
std
::
unique_ptr
<
GraphDumper
>
make_fbs_dumper
(
std
::
unique_ptr
<
OutputFile
>
file
);
std
::
unique_ptr
<
GraphLoader
>
make_fbs_v2_loader
(
std
::
unique_ptr
<
InputFile
>
file
);
std
::
unique_ptr
<
GraphDumper
>
make_fbs_v2_dumper
(
std
::
unique_ptr
<
OutputFile
>
file
);
bool
is_fbs_file
(
InputFile
&
file
);
bool
is_fbs_v2_file
(
InputFile
&
file
);
bool
GraphDumper
::
should_remove_in_dump
(
cg
::
OperatorNodeBase
*
opr
)
{
#if MGB_ENABLE_GRAD
...
...
@@ -73,6 +77,11 @@ std::unique_ptr<GraphDumper> GraphDumper::make(
case
GraphDumpFormat
::
FLATBUFFERS
:
#if MGB_ENABLE_FBS_SERIALIZATION
return
make_fbs_dumper
(
std
::
move
(
file
));
#endif
MGB_FALLTHRU
case
GraphDumpFormat
::
FLATBUFFERS_V2
:
#if MGB_ENABLE_FBS_SERIALIZATION
return
make_fbs_v2_dumper
(
std
::
move
(
file
));
#endif
MGB_FALLTHRU
default:
...
...
@@ -87,6 +96,11 @@ std::unique_ptr<GraphLoader> GraphLoader::make(
case
GraphDumpFormat
::
FLATBUFFERS
:
#if MGB_ENABLE_FBS_SERIALIZATION
return
make_fbs_loader
(
std
::
move
(
file
));
#endif
MGB_FALLTHRU
case
GraphDumpFormat
::
FLATBUFFERS_V2
:
#if MGB_ENABLE_FBS_SERIALIZATION
return
make_fbs_v2_loader
(
std
::
move
(
file
));
#endif
MGB_FALLTHRU
default:
...
...
@@ -100,6 +114,9 @@ Maybe<GraphDumpFormat> GraphLoader::identify_graph_dump_format(InputFile& file)
if
(
is_fbs_file
(
file
))
{
return
GraphDumpFormat
::
FLATBUFFERS
;
}
if
(
is_fbs_v2_file
(
file
))
{
return
GraphDumpFormat
::
FLATBUFFERS_V2
;
}
#endif
return
{};
}
...
...
src/serialization/impl/serializer_oss.cpp
浏览文件 @
a694fb33
...
...
@@ -11,17 +11,16 @@
*/
#if MGB_ENABLE_FBS_SERIALIZATION
#include "batched_device_value_loader.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/batched_device_value_loader.h"
#include "megbrain/serialization/helper.h"
#include "megbrain/serialization/internal/flatbuffers_helper.h"
#include "megbrain/serialization/internal/schema_generated.h"
#include "megbrain/serialization/metadata.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/serializer.h"
#include "
megbrain/versi
on.h"
#include "
serializer_oss_comm
on.h"
#include <flatbuffers/flatbuffers.h>
...
...
@@ -33,47 +32,8 @@ using namespace mgb;
using
namespace
mgb
::
serialization
;
namespace
{
constexpr
uint32_t
MGB_VERSION
=
(
MGE_MAJOR
*
1000
+
MGE_MINOR
)
*
100
+
MGE_PATCH
;
constexpr
uint32_t
MGB_MAGIC
=
0x4342474D
;
// In order to maintain compatibility and to allow old models to be loaded, we keep
// the old magic(MAGIC_V0) value and creat a new magic(MGB_MAGIC)
constexpr
uint32_t
MAGIC_V0
=
0x5342474D
;
// Used to judge whether Magic is old or new, the new magic(MGB_MAGIC) is true and the
// old magic(MAGIC_V0) is false.
bool
magic_compare
=
true
;
template
<
typename
T
>
bool
contains_any_in_set
(
const
SmallVector
<
T
>&
list
,
const
ThinHashSet
<
T
>&
set
)
{
for
(
const
auto
&
x
:
list
)
{
if
(
set
.
count
(
x
))
{
return
true
;
}
}
return
false
;
}
void
check_tensor_value_valid
(
const
std
::
string
&
name
,
const
HostTensorND
&
tensor
)
{
bool
cond_normal
=
tensor
.
layout
().
format
.
is_default
()
&&
tensor
.
layout
().
is_physical_contiguous
();
bool
cond_lowbit
=
tensor
.
layout
().
dtype
.
is_quantized_lowbit
()
&&
tensor
.
layout
().
format
.
is_lowbit_aligned
()
&&
tensor
.
layout
().
is_contiguous
();
mgb_assert
(
cond_normal
||
cond_lowbit
,
"non-contiguous tensor: name=%s layout=%s"
,
name
.
c_str
(),
tensor
.
layout
().
to_string
().
c_str
());
if
(
tensor
.
dtype
()
==
dtype
::
Float32
())
{
auto
ptr
=
tensor
.
ptr
<
float
>
();
for
(
size_t
i
=
0
,
it
=
tensor
.
shape
().
total_nr_elems
();
i
<
it
;
++
i
)
{
if
(
!
std
::
isfinite
(
ptr
[
i
]))
{
mgb_log_warn
(
"invalid tensor value in %s: %g"
,
name
.
c_str
(),
ptr
[
i
]);
break
;
}
}
}
}
//! feature bits for backward compatibility; default value should be 0
struct
FeatureBits64
{
//! reserved for new fields
...
...
@@ -947,13 +907,6 @@ std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file) {
return
std
::
make_unique
<
GraphLoaderOSS
>
(
std
::
move
(
file
));
}
bool
is_fbs_file
(
InputFile
&
file
)
{
uint64_t
magic_with_reserved
=
0
;
file
.
read
(
&
magic_with_reserved
,
sizeof
(
magic_with_reserved
));
file
.
skip
(
-
sizeof
(
magic_with_reserved
));
return
(
magic_with_reserved
==
MGB_MAGIC
)
||
(
magic_with_reserved
==
MAGIC_V0
);
}
}
// namespace serialization
}
// namespace mgb
...
...
src/serialization/impl/serializer_oss_common.cpp
0 → 100644
浏览文件 @
a694fb33
#if MGB_ENABLE_FBS_SERIALIZATION
#include "serializer_oss_common.h"
namespace
mgb
{
namespace
serialization
{
bool
is_fbs_file
(
InputFile
&
file
)
{
//! check whether the model format is flatbuffer v2
uint64_t
magic_with_reserved
=
0
;
file
.
read
(
&
magic_with_reserved
,
sizeof
(
magic_with_reserved
));
file
.
skip
(
-
sizeof
(
magic_with_reserved
));
return
(
magic_with_reserved
==
MGB_MAGIC
)
||
(
magic_with_reserved
==
MAGIC_V0
);
}
void
check_tensor_value_valid
(
const
std
::
string
&
name
,
const
HostTensorND
&
tensor
)
{
bool
cond_normal
=
tensor
.
layout
().
format
.
is_default
()
&&
tensor
.
layout
().
is_physical_contiguous
();
bool
cond_lowbit
=
tensor
.
layout
().
dtype
.
is_quantized_lowbit
()
&&
tensor
.
layout
().
format
.
is_lowbit_aligned
()
&&
tensor
.
layout
().
is_contiguous
();
mgb_assert
(
cond_normal
||
cond_lowbit
,
"non-contiguous tensor: name=%s layout=%s"
,
name
.
c_str
(),
tensor
.
layout
().
to_string
().
c_str
());
if
(
tensor
.
dtype
()
==
dtype
::
Float32
())
{
auto
ptr
=
tensor
.
ptr
<
float
>
();
for
(
size_t
i
=
0
,
it
=
tensor
.
shape
().
total_nr_elems
();
i
<
it
;
++
i
)
{
if
(
!
std
::
isfinite
(
ptr
[
i
]))
{
mgb_log_warn
(
"invalid tensor value in %s: %g"
,
name
.
c_str
(),
ptr
[
i
]);
break
;
}
}
}
}
}
// namespace serialization
}
// namespace mgb
#endif
src/serialization/impl/serializer_oss_common.h
0 → 100644
浏览文件 @
a694fb33
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/serialization/serializer.h"
#include "megbrain/version.h"
namespace
mgb
{
namespace
serialization
{
constexpr
uint32_t
MGB_VERSION
=
(
MGE_MAJOR
*
1000
+
MGE_MINOR
)
*
100
+
MGE_PATCH
;
constexpr
uint32_t
MGB_MAGIC
=
0x4342474D
;
// In order to maintain compatibility and to allow old models to be loaded, we keep
// the old magic(MAGIC_V0) value and creat a new magic(MGB_MAGIC)
constexpr
uint32_t
MAGIC_V0
=
0x5342474D
;
void
check_tensor_value_valid
(
const
std
::
string
&
name
,
const
HostTensorND
&
tensor
);
template
<
typename
T
>
bool
contains_any_in_set
(
const
SmallVector
<
T
>&
list
,
const
ThinHashSet
<
T
>&
set
)
{
for
(
const
auto
&
x
:
list
)
{
if
(
set
.
count
(
x
))
{
return
true
;
}
}
return
false
;
}
}
// namespace serialization
}
// namespace mgb
#endif
src/serialization/impl/serializer_oss_v2.cpp
0 → 100644
浏览文件 @
a694fb33
此差异已折叠。
点击以展开。
src/serialization/i
mpl
/batched_device_value_loader.h
→
src/serialization/i
nclude/megbrain/serialization
/batched_device_value_loader.h
浏览文件 @
a694fb33
文件已移动
src/serialization/include/megbrain/serialization/dump_format.h
浏览文件 @
a694fb33
...
...
@@ -5,6 +5,7 @@ namespace serialization {
enum
class
GraphDumpFormat
{
FLATBUFFERS
,
FLATBUFFERS_V2
,
};
}
// namespace serialization
...
...
src/serialization/include/megbrain/serialization/opr_load_dump.h
浏览文件 @
a694fb33
...
...
@@ -20,8 +20,12 @@ class FlatBufferBuilder;
}
// namespace flatbuffers
namespace
mgb
{
namespace
serialization
{
constexpr
uint8_t
CURRENT_VERSION
=
2u
;
constexpr
uint8_t
BEGIN_VERSION
=
0u
;
constexpr
uint8_t
VERSION_1
=
1u
;
constexpr
uint8_t
VERSION_2
=
2u
;
namespace
serialization
{
namespace
fbs
{
template
<
typename
T
>
struct
OperatorParamTraits
;
...
...
@@ -187,6 +191,9 @@ class OprLoadContext : public UserDataContainer::UserData {
friend
class
OprLoadContextRawPOD
;
friend
class
OprLoadContextFlatBuffers
;
protected:
virtual
~
OprLoadContext
()
=
default
;
public:
//! get current computing graph
virtual
ComputingGraph
&
graph
()
=
0
;
...
...
@@ -224,6 +231,12 @@ public:
*/
virtual
SharedBuffer
load_shared_buf_with_len
()
=
0
;
/*!
* \brief get the serialization data of the current opr
*
*/
virtual
const
void
*
get_current_opr_data
()
{
return
nullptr
;
};
/*!
* \brief read a param and check that tag matches
*/
...
...
src/serialization/include/megbrain/serialization/oss_opr_load_dump.h
0 → 100644
浏览文件 @
a694fb33
#pragma once
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/serialization/batched_device_value_loader.h"
#include "megbrain/serialization/internal/schema_v2_generated.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/serializer.h"
#define CAST_TO_FBS_V2_CTX(cvt) static_cast<GraphLoaderOSSV2::OprLoadContextImpl&>(ctx)
namespace
mgb
{
namespace
serialization
{
class
GraphDumperOSSV2
final
:
public
GraphDumper
,
OprDumpContextFlatBuffers
{
const
std
::
unique_ptr
<
OutputFile
>
m_file
;
flatbuffers
::
FlatBufferBuilder
m_builder
;
DumpConfig
m_config
;
DumpResult
m_cur_rst
;
size_t
m_nr_shared_tensor
;
std
::
vector
<
std
::
pair
<
cg
::
OperatorNodeBase
*
,
const
OprRegistryV2
*>>
m_oprs_to_dump
;
ThinHashMap
<
VarNode
*
,
VarNode
*>
m_var_remove_in_dump
;
//! set of output vars specified by user
ThinHashSet
<
VarNode
*>
m_output_vars
;
std
::
unordered_set
<
std
::
string
>
m_used_input_names
,
m_used_param_names
;
//! current opr to be dumped
cg
::
OperatorNodeBase
*
m_cur_opr
=
nullptr
;
// Will be filled in dump_tensor
std
::
vector
<
flatbuffers
::
Offset
<
fbs
::
v2
::
Tensor
>>
m_cur_opr_tensor
;
std
::
vector
<
flatbuffers
::
Offset
<
fbs
::
v2
::
Blob
>>
m_blobs
;
std
::
vector
<
fbs
::
v2
::
OperatorParam
>
m_cur_opr_param_type
;
std
::
vector
<
flatbuffers
::
Offset
<
void
>>
m_cur_opr_param
;
std
::
vector
<
flatbuffers
::
Offset
<
fbs
::
v2
::
MiddleTensor
>>
m_model_middle_tensors
;
ThinHashMap
<
VarNode
*
,
size_t
>
m_var2midtensor_id
;
SymbolVarArray
converter_all_opr_to_compatiable
(
const
SymbolVarArray
&
output_vars
);
void
init_oprs_to_dump
(
const
SymbolVarArray
&
endpoints
);
flatbuffers
::
Offset
<
fbs
::
v2
::
Metadata
>
build_metadata
(
const
Metadata
&
metadata
);
flatbuffers
::
Offset
<
fbs
::
v2
::
Operator
>
build_single_opr
(
cg
::
OperatorNodeBase
*
opr
,
const
OprRegistryV2
*
registry
);
flatbuffers
::
Offset
<
fbs
::
DType
>
build_dtype
(
DType
dtype
);
public:
GraphDumperOSSV2
(
std
::
unique_ptr
<
OutputFile
>
file
)
:
m_file
{
std
::
move
(
file
)}
{}
DumpResult
dump
(
const
SymbolVarArray
&
output_vars
,
const
DumpConfig
&
config
=
{},
const
Metadata
&
metadata
=
{})
override
;
const
GraphDumpConfig
&
config
()
const
override
{
return
m_config
;
}
void
dump_tensor
(
const
std
::
string
&
name
,
const
HostTensorND
&
tensor
,
TensorWriteMethod
method
)
override
;
void
append_param
(
uint32_t
type
,
uint32_t
value
)
override
{
static_assert
(
std
::
is_same
<
uint32_t
,
flatbuffers
::
uoffset_t
>::
value
,
"append_param depends on uoffset_t being uint32_t"
);
static_assert
(
std
::
is_standard_layout
<
flatbuffers
::
Offset
<
void
>>::
value
,
"append_param depends on flatbuffers::Offset having "
"standard memory layout"
);
mgb_assert
(
type
!=
fbs
::
v2
::
OperatorParam_NONE
);
m_cur_opr_param_type
.
emplace_back
(
static_cast
<
fbs
::
v2
::
OperatorParam
>
(
type
));
m_cur_opr_param
.
emplace_back
(
value
);
}
flatbuffers
::
FlatBufferBuilder
&
builder
()
override
{
return
m_builder
;
}
void
dump_buf_with_len
(
const
void
*
data
,
uint32_t
size
)
override
;
GraphDumpFormat
format
()
const
override
{
return
GraphDumpFormat
::
FLATBUFFERS_V2
;
}
flatbuffers
::
Offset
<
fbs
::
v2
::
MiddleTensor
>
build_middle_tensor
(
const
SymbolVar
var
);
flatbuffers
::
Offset
<
fbs
::
v2
::
OutputVar
>
build_output_var
(
const
SymbolVar
var
);
flatbuffers
::
Offset
<
void
>
build_tensor_format
(
const
TensorLayout
::
Format
&
format
);
void
set_current_opr
(
cg
::
OperatorNodeBase
*
cur_opr
)
{
m_cur_opr
=
cur_opr
;
}
};
// ----------------------------- Loader --------------------------------------
class
GraphLoaderOSSV2
final
:
public
GraphLoader
{
const
LoadConfig
*
m_cur_load_config
=
nullptr
;
std
::
unique_ptr
<
InputFile
>
m_file
;
SharedBuffer
m_model_buf
{{},
0
};
const
fbs
::
v2
::
Model
*
m_model
;
SharedTensorIDMap
m_shared_tensor_map
;
uint32_t
m_mgb_version
=
0
;
bool
m_model_loaded
=
false
;
void
verify
();
public:
class
OprLoadContextImpl
;
friend
class
OprLoadContextImpl
;
GraphLoaderOSSV2
(
std
::
unique_ptr
<
InputFile
>
input_file
)
:
m_file
{
std
::
move
(
input_file
)}
{}
std
::
unique_ptr
<
InputFile
>
reset_file
(
std
::
unique_ptr
<
InputFile
>
file
)
override
{
file
.
swap
(
m_file
);
return
file
;
}
LoadResult
load
(
const
LoadConfig
&
config
,
bool
rewind
)
override
;
const
SharedTensorIDMap
&
shared_tensor_id_map
()
const
override
{
mgb_assert
(
m_model_loaded
,
"graph not loaded yet"
);
return
m_shared_tensor_map
;
}
GraphDumpFormat
format
()
const
override
{
return
GraphDumpFormat
::
FLATBUFFERS_V2
;
}
};
class
GraphLoaderOSSV2
::
OprLoadContextImpl
final
:
public
OprLoadContextFlatBuffers
{
GraphLoaderOSSV2
*
const
m_loader
;
size_t
m_cur_shared_tensor_idx
=
0
;
std
::
shared_ptr
<
ComputingGraph
>
m_graph
;
LoadResult
::
TensorMap
m_tensor_map
;
VarNodeArray
m_id2varnode
;
std
::
vector
<
const
fbs
::
v2
::
MiddleTensor
*>
m_middle_tensors
;
BatchedDeviceValueLoader
m_device_value_loader
;
const
fbs
::
v2
::
Operator
*
m_current_opr
;
size_t
m_cur_opr_tensor_cnt
;
size_t
m_cur_opr_blob_cnt
;
size_t
m_cur_opr_param_cnt
;
public:
ComputingGraph
&
graph
()
override
{
return
*
m_graph
;
}
const
GraphLoadConfig
&
config
()
const
override
{
return
*
m_loader
->
m_cur_load_config
;
}
std
::
shared_ptr
<
HostTensorND
>
load_tensor
()
override
;
std
::
shared_ptr
<
DeviceTensorND
>
load_tensor_shared
()
override
;
void
load_single_opr
(
const
fbs
::
v2
::
Operator
*
opr
);
OprLoadContextImpl
(
GraphLoaderOSSV2
*
loader
,
uint32_t
version
)
:
OprLoadContextFlatBuffers
(
version
),
m_loader
{
loader
}
{
m_graph
=
loader
->
m_cur_load_config
->
comp_graph
;
if
(
!
m_graph
)
{
m_graph
=
ComputingGraph
::
make
();
}
auto
maker
=
[
this
]()
{
return
std
::
shared_ptr
<
OprLoadContext
>
{
std
::
shared_ptr
<
OprLoadContext
>
{},
this
};
};
auto
got
=
m_graph
->
options
().
user_data
.
get_user_data_or_create
<
OprLoadContext
>
(
maker
);
mgb_assert
(
got
==
this
);
}
~
OprLoadContextImpl
()
noexcept
{
auto
nr
=
m_graph
->
options
().
user_data
.
pop_user_data
<
OprLoadContext
>
();
mgb_assert
(
nr
==
1
);
}
Metadata
load_metadata
();
LoadResult
load_oprs
();
CompNode
load_comp_node
(
const
fbs
::
v2
::
CompNode
*
comp_node
);
void
load_middle_tensor
();
const
void
*
get_next_param
(
uint32_t
enumv
)
override
{
auto
type
=
static_cast
<
fbs
::
v2
::
OperatorParam
>
(
enumv
);
if
(
m_cur_opr_param_cnt
==
0
)
{
m_cur_opr_param_cnt
++
;
if
(
m_current_opr
->
param_type
()
==
type
)
{
return
m_current_opr
->
param
();
}
else
{
mgb_throw
(
SerializationError
,
"The param type is not match when load the opr."
);
}
}
mgb_throw
(
SerializationError
,
"When load multi param in one Operator, please use read_param(index) "
"interface. "
);
}
std
::
string
load_buf_with_len
()
override
{
mgb_assert
(
m_current_opr
->
custom_data
()
&&
m_cur_opr_blob_cnt
<
m_current_opr
->
custom_data
()
->
size
());
auto
blob
=
m_current_opr
->
custom_data
()
->
Get
(
m_cur_opr_blob_cnt
++
);
mgb_assert
(
blob
&&
blob
->
data
());
auto
data
=
blob
->
data
()
->
data
();
return
{
reinterpret_cast
<
const
char
*>
(
data
),
blob
->
data
()
->
size
()};
}
SharedBuffer
load_shared_buf_with_len
()
override
{
mgb_assert
(
m_current_opr
->
custom_data
()
&&
m_cur_opr_blob_cnt
<
m_current_opr
->
custom_data
()
->
size
());
auto
blob
=
m_current_opr
->
custom_data
()
->
Get
(
m_cur_opr_blob_cnt
++
);
mgb_assert
(
blob
&&
blob
->
data
());
auto
size
=
blob
->
data
()
->
size
();
std
::
shared_ptr
<
uint8_t
>
shptr
{
new
uint8_t
[
size
],
[](
uint8_t
*
p
)
{
delete
[]
p
;
}};
memcpy
(
shptr
.
get
(),
blob
->
data
()
->
data
(),
size
);
return
{
std
::
move
(
shptr
),
size
};
}
const
void
*
get_current_opr_data
()
override
{
return
reinterpret_cast
<
const
void
*>
(
m_current_opr
);
}
template
<
class
T
>
T
read_param
(
int
index
)
{
using
SourceType
=
typename
fbs
::
ParamConverter
<
T
>::
FlatBufferType
;
auto
enumv
=
fbs
::
OperatorParamTraits
<
SourceType
>::
enum_value
;
auto
type
=
static_cast
<
fbs
::
v2
::
OperatorParam
>
(
enumv
);
if
(
index
==
0
)
{
mgb_assert
(
m_current_opr
->
param_type
()
==
type
,
"Load param error, the param type is not right."
);
return
fbs
::
ParamConverter
<
T
>::
to_param
(
static_cast
<
const
SourceType
*>
(
m_current_opr
->
param
()));
}
else
{
int
addition_index
=
index
-
1
;
if
(
addition_index
>=
static_cast
<
int
>
(
m_current_opr
->
additional_params
()
->
size
()))
{
mgb_log_warn
(
"Model has no addition param of index %d, just construct a "
"default one."
,
addition_index
);
}
else
{
mgb_assert
(
m_current_opr
->
additional_params_type
()
->
Get
(
addition_index
)
==
type
,
"Load param error, the addition param type is not right."
);
return
fbs
::
ParamConverter
<
T
>::
to_param
(
static_cast
<
const
SourceType
*>
(
m_current_opr
->
additional_params
()
->
Get
(
addition_index
)));
}
}
}
};
}
// namespace serialization
}
// namespace mgb
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录