MegEngine 天元 / MegEngine · Commit 73ad06ba
Authored Aug 10, 2022 by Megvii Engine Team
fix(mgb): fix fbsv2 model format no dump tensor format
GitOrigin-RevId: 29f785b7019bbc5a591d5992020798f6f6a9ae6f
Parent: 399200b3
Showing 9 changed files with 117 additions and 84 deletions (+117, -84):
imperative/src/impl/ops/opr_attr.cpp                                  +2  -2
src/opr/impl/io.sereg.v2.h                                            +13 -5
src/opr/test/io.cpp                                                   +58 -40
src/serialization/impl/opr_shallow_copy.cpp                           +3  -1
src/serialization/impl/serializer_oss.cpp                             +10 -9
src/serialization/impl/serializer_oss_v2.cpp                          +26 -24
src/serialization/include/megbrain/serialization/file.h               +1  -1
src/serialization/include/megbrain/serialization/opr_load_dump.h      +1  -1
src/serialization/include/megbrain/serialization/oss_opr_load_dump.h  +3  -1
imperative/src/impl/ops/opr_attr.cpp

@@ -27,7 +27,7 @@ class OprParamsLoadContext final : public serialization::OprLoadContextRawPOD {
     std::shared_ptr<DeviceTensorND> load_tensor_shared(
             bool copy_immediatly = false) override {
-        (void)copy_immediatly;
+        MGB_MARK_USED_VAR(copy_immediatly);
         mgb_assert(0);
     }

@@ -56,7 +56,7 @@ public:
     }
     void dump_tensor(
             const std::string& name, const HostTensorND& tensor,
-            TensorWriteMethod method) {
+            TensorWriteMethod method, TensorFormat format = {}) {
         mgb_assert(0);
     }
     const serialization::GraphDumpConfig& config() const { mgb_assert(0); }
src/opr/impl/io.sereg.v2.h

@@ -72,16 +72,20 @@ struct OprLoadDumpImplV2<opr::SharedDeviceTensorWithFormat, 0> {
         auto&& opr = opr_.cast_final_safe<opr::SharedDeviceTensorWithFormat>();
         HostTensorND val;
         val.copy_from(opr.get_dev_tensor()).sync();
-        ctx.dump_tensor({}, val, Meth::VALUE_ANONYMOUS);
+        ctx.dump_tensor(
+                {}, val, Meth::VALUE_ANONYMOUS, opr.get_dev_tensor().layout().format);
     }
     static cg::OperatorNodeBase* load(
             OprLoadContext& ctx, const cg::VarNodeArray& inputs,
             const OperatorNodeConfig& config) {
         mgb_assert(inputs.empty());
-        auto val = ctx.load_tensor();
+        auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
+        auto val = fbs_ctx.load_tensor();
+        auto format = fbs_ctx.load_tensor_format(0);
+        TensorLayout layout_with_format = {val->shape(), val->dtype(), format};
         auto dev_val =
-                std::make_shared<DeviceTensorND>(val->comp_node(), val->layout());
+                std::make_shared<DeviceTensorND>(val->comp_node(), layout_with_format);
         dev_val->copy_from_fixlayout(*val);
         auto out_var =
                 opr::SharedDeviceTensorWithFormat::make(ctx.graph(), dev_val, config);

@@ -136,7 +140,9 @@ struct OprLoadDumpImplV2<opr::MultipleDeviceTensorWithFormatHolder, 0> {
             HostTensorND val;
             auto value = *opr.values()[i];
             val.copy_from(value).sync();
-            ctx.dump_tensor(opr.output(i)->name(), val, Meth::VALUE_SHARED);
+            ctx.dump_tensor(
+                    opr.output(i)->name(), val, Meth::VALUE_SHARED,
+                    value.layout().format);
         }
     }

@@ -152,10 +158,12 @@ struct OprLoadDumpImplV2<opr::MultipleDeviceTensorWithFormatHolder, 0> {
             nr = fopr->tensors()->size();
         }
         Opr::ValueArray values(nr);
+        size_t id = 0;
         for (auto&& i : values) {
             i = ctx.load_tensor_shared();
             //! set tensor format
-            TensorLayout layout_with_format = i->layout();
+            auto format = fbs_ctx.load_tensor_format(id++);
+            TensorLayout layout_with_format{i->layout(), i->layout().dtype, format};
             if (i->storage().comp_node().mem_node() ==
                 CompNode::default_cpu().mem_node()) {
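Note: the io.sereg.v2.h hunks above are the core of the fix. On the dump side, `val` is a fresh host copy, so `val.layout().format` is the default format; the device tensor's format (for example IMAGE2D_PACK4) has to be forwarded explicitly or it never reaches the serialized model. On the load side, the stored format is fetched back by index and folded into the layout before copying onto the device. Below is a condensed sketch of the round trip, as two independent fragments mirroring the calls above (illustrative only; it assumes the surrounding MegEngine serialization headers, with `Meth` aliasing `serialization::TensorWriteMethod` as in these files):

    // dump side: pass the device layout's format as the new 4th argument
    HostTensorND val;
    val.copy_from(opr.get_dev_tensor()).sync();  // host copy: default format
    ctx.dump_tensor(
            {}, val, Meth::VALUE_ANONYMOUS, opr.get_dev_tensor().layout().format);

    // load side: re-attach the stored format before the device copy
    auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);      // format lookup is V2-only
    auto val = fbs_ctx.load_tensor();             // value, contiguous layout
    auto format = fbs_ctx.load_tensor_format(0);  // format of this opr's tensor #0
    TensorLayout layout_with_format = {val->shape(), val->dtype(), format};
    auto dev_val =
            std::make_shared<DeviceTensorND>(val->comp_node(), layout_with_format);
    dev_val->copy_from_fixlayout(*val);           // data now laid out per `format`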
src/opr/test/io.cpp

@@ -498,48 +498,66 @@ TEST(TestOprIO, MultipleDeviceTensorWithFormatHolderCpu) {
     auto fname = GET_OUTPUT_FILE();
     auto cn = CompNode::load("cpu0");
     HostTensorGenerator<> gen;
-    {
-        // dump
-        auto graph = ComputingGraph::make();
-        graph->options().graph_opt_level = 0;
-        auto mkcvar = [&](const char* name, const TensorShape& shp) {
-            return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
-        };
-        auto host_x = gen({8, 8, 8, 8}, cn);
-        auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"});
-        opr::Convolution::Param param;
-        param.pad_h = param.pad_w = 0;
-        auto w1 = mkcvar("w1", {4, 8, 3, 3}),
-             conv1 = opr::Convolution::make(x, w1, param);
-        auto w2 = mkcvar("w2", {4, 4, 3, 3}),
-             conv2 = opr::Convolution::make(conv1, w2, param);
-        auto y = opr::Elemwise::make({conv2}, opr::Elemwise::Param::Mode::RELU);
-        auto options = gopt::OptimizeForInferenceOptions{};
-        options.enable_nhwcd4();
-        SymbolVar y_opt = gopt::optimize_for_inference({y}, options)[0].rename("out");
-        auto dumper = serialization::GraphDumper::make(
-                serialization::OutputFile::make_fs(fname.c_str()));
-        serialization::GraphDumper::DumpConfig config;
-        config.keep_param_name = true;
-        dumper->dump({y_opt}, config);
-    }
-    auto loader = serialization::GraphLoader::make(
-            serialization::InputFile::make_fs(fname.c_str()));
-    auto load = [&](CompNode dest_cn) {
-        auto dest_cn_loc = dest_cn.locator_logical();
-        auto rst = loader->load(
-                {[&](CompNode::Locator& loc) { loc = dest_cn_loc; }});
-        HostTensorND host_z, host_z_expect;
-        auto func = rst.graph_compile(
-                {make_callback_copy(rst.output_var_map.at("out"), host_z)});
-        func->execute();
-    };
-    load(cn);
+    auto test = [&](serialization::GraphDumpFormat format) {
+        {
+            // dump
+            auto graph = ComputingGraph::make();
+            graph->options().graph_opt_level = 0;
+            auto mkcvar = [&](const char* name, const TensorShape& shp) {
+                return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                        .rename(name);
+            };
+            auto host_x = gen({8, 8, 8, 8}, cn);
+            auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"});
+            opr::Convolution::Param param;
+            param.pad_h = param.pad_w = 0;
+            auto w1 = mkcvar("w1", {4, 8, 3, 3}),
+                 conv1 = opr::Convolution::make(x, w1, param);
+            auto w2 = mkcvar("w2", {4, 4, 3, 3}),
+                 conv2 = opr::Convolution::make(conv1, w2, param);
+            auto y = opr::Elemwise::make({conv2}, opr::Elemwise::Param::Mode::RELU);
+            auto options = gopt::OptimizeForInferenceOptions{};
+            options.enable_nhwcd4();
+            SymbolVar y_opt =
+                    gopt::optimize_for_inference({y}, options)[0].rename("out");
+            auto dumper = serialization::GraphDumper::make(
+                    serialization::OutputFile::make_fs(fname.c_str()), format);
+            serialization::GraphDumper::DumpConfig config;
+            config.keep_param_name = true;
+            dumper->dump({y_opt}, config);
+        }
+        auto loader = serialization::GraphLoader::make(
+                serialization::InputFile::make_fs(fname.c_str()), format);
+        auto load = [&](CompNode dest_cn) {
+            auto dest_cn_loc = dest_cn.locator_logical();
+            auto rst = loader->load(
+                    {[&](CompNode::Locator& loc) { loc = dest_cn_loc; }});
+            HostTensorND host_z, host_z_expect;
+            auto func = rst.graph_compile(
+                    {make_callback_copy(rst.output_var_map.at("out"), host_z)});
+            func->execute();
+            func->wait();
+            auto&& shared_tensor_map = loader->shared_tensor_id_map();
+            bool cd4 = false;
+            for (auto&& i : shared_tensor_map) {
+                auto&& shared_tensor = i.second.begin()->second;
+                if (shared_tensor->format().type() ==
+                    TensorFormat::Type::IMAGE2D_PACK4) {
+                    cd4 = true;
+                }
+            }
+            ASSERT_TRUE(cd4);
+        };
+        load(cn);
+    };
+    test({});
+    test(serialization::GraphDumpFormat::FLATBUFFERS_V2);
 }
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
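Note: the rewritten test drives the whole dump/load cycle twice, once with the default (legacy FlatBuffers) format via `test({})` and once with FLATBUFFERS_V2, and asserts through the `cd4` flag that at least one shared tensor comes back from `shared_tensor_id_map()` with its IMAGE2D_PACK4 format intact, which is exactly what the broken fbsv2 path used to lose. The same parameterization pattern applies to other serialization tests; a minimal sketch (the `run_case` helper is hypothetical, standing in for a body like the `test` lambda above):

    for (auto format :
         {serialization::GraphDumpFormat{},  // default, i.e. legacy flatbuffers
          serialization::GraphDumpFormat::FLATBUFFERS_V2}) {
        run_case(format);  // hypothetical per-format dump/load/assert body
    }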
src/serialization/impl/opr_shallow_copy.cpp

@@ -32,7 +32,9 @@ class OprDumpContextMemory final : public OprDumpContextRawPOD {
     }
     void dump_tensor(
-            const std::string&, const HostTensorND&, TensorWriteMethod) override {
+            const std::string&, const HostTensorND&, TensorWriteMethod,
+            TensorFormat format = {}) override {
+        MGB_MARK_USED_VAR(format);
         mgb_throw(GraphError, "OprDumpContextMemory does not support dump tensor");
     }
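Note: because the new parameter is defaulted (`TensorFormat format = {}`) in the interface and in stubs like the one above, existing call sites keep compiling unchanged, and the legacy GraphDumperOSS simply ignores it (its parameter is unnamed, as the serializer_oss.cpp hunk below shows), so v1 model files are untouched. A sketch of the two call shapes (illustrative; `ctx`, `name`, `tensor` and `dev_tensor` stand for objects in scope at a real call site, `Meth` for `TensorWriteMethod`):

    ctx.dump_tensor(name, tensor, Meth::VALUE_SHARED);  // old call, still valid
    ctx.dump_tensor(
            name, tensor, Meth::VALUE_SHARED,
            dev_tensor.layout().format);                // forwards the format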
src/serialization/impl/serializer_oss.cpp

@@ -92,7 +92,7 @@ public:
     const GraphDumpConfig& config() const override { return m_config; }
     void dump_tensor(
             const std::string& name, const HostTensorND& tensor,
-            TensorWriteMethod method) override;
+            TensorWriteMethod method, TensorFormat format = {}) override;
     flatbuffers::FlatBufferBuilder& builder() override { return m_builder; }
     void append_param(uint32_t type, uint32_t value) override {
         static_assert(

@@ -359,7 +359,8 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
 }
 void GraphDumperOSS::dump_tensor(
-        const std::string& name, const HostTensorND& tensor, TensorWriteMethod method) {
+        const std::string& name, const HostTensorND& tensor, TensorWriteMethod method,
+        TensorFormat) {
     using namespace flatbuffers;
     using Meth = TensorWriteMethod;
     mgb_assert(

@@ -671,17 +672,17 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSS::OprLoadContextImpl::load_tensor_
         sh_reg.first = tensor->name()->str();
     }
-    if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
+    if (comp_node.mem_node() == CompNode::default_cpu().mem_node() ||
+        copy_immediatly) {
         // directly forward CPU memory
         HostTensorND hv{comp_node};
         load_tensor_value(&hv, layout, tensor);
         sh_ptr_ref = std::make_shared<DeviceTensorND>();
-        *sh_ptr_ref = DeviceTensorND::make_proxy(hv);
-    } else if (copy_immediatly) {
-        HostTensorND hv{CompNode::default_cpu()};
-        load_tensor_value(&hv, layout, tensor);
-        sh_ptr_ref = std::make_shared<DeviceTensorND>();
-        sh_ptr_ref->comp_node(comp_node).copy_from(hv).sync();
+        if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
+            *sh_ptr_ref = DeviceTensorND::make_proxy(hv);
+        } else {
+            mgb_assert(copy_immediatly);
+            sh_ptr_ref->comp_node(comp_node).copy_from(hv).sync();
+        }
     } else {
         // use lazy load for non-CPU devices
         HostTensorND hv{CompNode::default_cpu()};
src/serialization/impl/serializer_oss_v2.cpp

@@ -455,7 +455,8 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump(
 }
 void GraphDumperOSSV2::dump_tensor(
-        const std::string& name, const HostTensorND& tensor, TensorWriteMethod method) {
+        const std::string& name, const HostTensorND& tensor, TensorWriteMethod method,
+        TensorFormat format) {
     using namespace flatbuffers;
     using Meth = TensorWriteMethod;
     mgb_assert(

@@ -510,8 +511,8 @@ void GraphDumperOSSV2::dump_tensor(
             m_builder.CreateSharedString(tensor.comp_node().to_string_logical()));
     auto fdtype = build_dtype(layout.dtype);
-    auto fformat_type = get_flatbuffer_tensor_format_type(layout.format);
-    auto fformat = build_tensor_format(layout.format);
+    auto fformat_type = get_flatbuffer_tensor_format_type(format);
+    auto fformat = build_tensor_format(format);
     auto serialized_tensor = fbs::v2::CreateTensor(
             m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat,
             data);
     m_cur_opr_tensor.emplace_back(serialized_tensor);

@@ -605,7 +606,7 @@ CompNode GraphLoaderOSSV2::OprLoadContextImpl::load_comp_node(
     return CompNode::load(loc);
 }
-TensorFormat load_tensor_format(
+TensorFormat get_tensor_format(
         const fbs::v2::TensorFormat fformat_type, const void* fformat,
         const CompNode& comp_node) {
     switch (fformat_type) {

@@ -631,8 +632,7 @@ TensorFormat load_tensor_format(
     }
 }
-TensorLayout load_tensor_layout(
-        const fbs::v2::Tensor* tensor, const CompNode& comp_node) {
+TensorLayout load_tensor_layout_without_format(const fbs::v2::Tensor* tensor) {
     TensorLayout layout;
     if (tensor->shape()) {
         layout.ndim = tensor->shape()->size();

@@ -642,14 +642,21 @@ TensorLayout load_tensor_layout(
         // modify data type inplace for TensorLayout
         layout.modify_dtype_inplace(fbs::intl::load_dtype(tensor->dtype()));
     }
-    if (tensor->format() && tensor->format_type()) {
-        layout.format = load_tensor_format(
-                tensor->format_type(), tensor->format(), comp_node);
-    }
     layout.init_contiguous_stride();
     return layout;
 }
+
+TensorFormat GraphLoaderOSSV2::OprLoadContextImpl::load_tensor_format(size_t id) {
+    mgb_assert(
+            m_current_opr->tensors() && id < m_current_opr->tensors()->size());
+    auto tensor = m_current_opr->tensors()->Get(id);
+    auto comp_node = load_comp_node(tensor->comp_node());
+    TensorFormat format;
+    if (tensor->format() && tensor->format_type()) {
+        format = get_tensor_format(
+                tensor->format_type(), tensor->format(), comp_node);
+    }
+    return format;
+}
+
 //! the opr loader should make sure the exist of tensors and the number of
 //! tensor, here just assert it.
 std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor() {

@@ -658,7 +665,7 @@ std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor(
             m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
     auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
     auto comp_node = load_comp_node(tensor->comp_node());
-    auto layout = load_tensor_layout(tensor, comp_node);
+    auto layout = load_tensor_layout_without_format(tensor);
     auto ret = std::make_shared<HostTensorND>(comp_node, layout);
     auto&& loader = m_loader->m_cur_load_config->tensor_value_loader;

@@ -692,7 +699,7 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
             m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
     auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
     auto comp_node = load_comp_node(tensor->comp_node());
-    auto layout = load_tensor_layout(tensor, comp_node);
+    auto layout = load_tensor_layout_without_format(tensor);
     mgb_assert(tensor->data());
     if (m_loader->m_shared_tensor_map.size() <= m_cur_shared_tensor_idx) {
         m_loader->m_shared_tensor_map.resize(m_cur_shared_tensor_idx + 5);

@@ -712,7 +719,7 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
         shared_pair.first = tensor->name()->str();
     }
-    if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
+    if (comp_node.mem_node() == CompNode::default_cpu().mem_node() ||
+        copy_immediatly) {
         // directly forward CPU memory
         shared_tensor_ref = std::make_shared<DeviceTensorND>();
         HostTensorND hv{comp_node};

@@ -722,18 +729,13 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
                     hv, tensor->data()->data(), tensor->data()->size(),
                     m_loader->m_file->is_shared_memory());
         }
-        *shared_tensor_ref = DeviceTensorND::make_proxy(hv);
-        m_tensor_alignment->add_device_tensor(shared_tensor_ref);
-    } else if (copy_immediatly) {
-        HostTensorND hv{CompNode::default_cpu()};
-        shared_tensor_ref = std::make_shared<DeviceTensorND>();
-        if (tensor->data() && tensor->data()->size() > 0) {
-            hv.dtype(layout.dtype).resize(layout);
-            fill_tensor_memory(
-                    hv, tensor->data()->data(), tensor->data()->size(),
-                    m_loader->m_file->is_shared_memory());
-        }
-        shared_tensor_ref->comp_node(comp_node).copy_from(hv).sync();
+        if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
+            *shared_tensor_ref = DeviceTensorND::make_proxy(hv);
+            m_tensor_alignment->add_device_tensor(shared_tensor_ref);
+        } else {
+            mgb_assert(copy_immediatly);
+            shared_tensor_ref->comp_node(comp_node).copy_from(hv).sync();
+        }
     } else {
         // use lazy load for non-CPU devices
         HostTensorND hv{CompNode::default_cpu()};
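Note: the V2 loader now separates the two concerns that `load_tensor_layout` used to mix: `load_tensor_layout_without_format` always builds a contiguous, default-format layout, while the new member `load_tensor_format(id)` resolves the stored format by index into the current operator's tensor table. Unlike `load_tensor()`, it does not advance the `m_cur_opr_tensor_cnt` cursor, so an opr loader can query formats independently of the order in which it consumes tensor values. A sketch of the resulting contract (derived from the hunk above; the remark about the default is an inference from the value-initialized local `format`):

    // formats are random-access by index within the current opr:
    auto fmt = fbs_ctx.load_tensor_format(0);
    // asserts that tensor #0 exists; yields a default-constructed TensorFormat
    // when the flatbuffer stored no format for it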
src/serialization/include/megbrain/serialization/file.h

@@ -47,7 +47,7 @@ public:
     //! whether this can be write
     virtual bool writable() { return false; }
-    //! whether this file have been wrote
+    //! tag this file have been wrote
     virtual void have_modified() {}
     /*!
src/serialization/include/megbrain/serialization/opr_load_dump.h

@@ -63,7 +63,7 @@ public:
     */
     virtual void dump_tensor(
             const std::string& name, const HostTensorND& tensor,
-            TensorWriteMethod method) = 0;
+            TensorWriteMethod method, TensorFormat format = {}) = 0;
     //! get associated global configuration
     virtual const GraphDumpConfig& config() const = 0;
src/serialization/include/megbrain/serialization/oss_opr_load_dump.h

@@ -63,7 +63,7 @@ public:
     void dump_tensor(
             const std::string& name, const HostTensorND& tensor,
-            TensorWriteMethod method) override;
+            TensorWriteMethod method, TensorFormat format = {}) override;
     void append_param(uint32_t type, uint32_t value) override {
         static_assert(

@@ -148,6 +148,8 @@ public:
         return *m_loader->m_cur_load_config;
     }
+    TensorFormat load_tensor_format(size_t id);
+
     //! shared or copy the loaded flatbuffer memory to the CPU tensor, this can reduce
     //! the memory used when load model, but should consider the memory
     //! alignment