Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
896b0193
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
896b0193
编写于
9月 26, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mgb): fix zero copy error when model end with memory forward opr
GitOrigin-RevId: 2eba697d85a9970ba8e947cc66cdc8dbbcc32242
上级
babecba2
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
153 addition
and
32 deletion
+153
-32
lite/src/mge/network_impl.cpp
lite/src/mge/network_impl.cpp
+0
-29
lite/src/mge/network_impl.h
lite/src/mge/network_impl.h
+0
-3
src/gopt/test/network.cpp
src/gopt/test/network.cpp
+30
-0
src/gopt/test/network.h
src/gopt/test/network.h
+6
-0
src/gopt/test/no_memory_copy.cpp
src/gopt/test/no_memory_copy.cpp
+90
-0
src/serialization/impl/serializer.cpp
src/serialization/impl/serializer.cpp
+27
-0
未找到文件。
lite/src/mge/network_impl.cpp
浏览文件 @
896b0193
...
@@ -435,31 +435,6 @@ void NetworkImplDft::cross_compnode_model_detect() {
...
@@ -435,31 +435,6 @@ void NetworkImplDft::cross_compnode_model_detect() {
m_nr_device_type
=
nr_used_device_type
.
size
();
m_nr_device_type
=
nr_used_device_type
.
size
();
}
}
void
NetworkImplDft
::
adapt_option_valid
()
{
auto
&&
options
=
m_load_config
.
comp_graph
->
options
();
if
(
m_user_config
->
options
.
force_output_use_user_specified_memory
)
{
for
(
auto
&&
out
:
m_load_result
.
output_var_list
)
{
auto
opr
=
out
.
node
()
->
owner_opr
();
//! all the dest operator inherit from ReadonlyFwdHelper can't
//! support force_output_use_user_specified_memory options
if
(
opr
->
try_cast_final
<
mgb
::
opr
::
Reshape
>
()
||
opr
->
try_cast_final
<
mgb
::
opr
::
Broadcast
>
()
||
opr
->
try_cast_final
<
mgb
::
opr
::
Subtensor
>
()
||
opr
->
try_cast_final
<
mgb
::
opr
::
AxisAddRemove
>
()
||
opr
->
try_cast_final
<
mgb
::
opr
::
Dimshuffle
>
())
{
m_user_config
->
options
.
force_output_use_user_specified_memory
=
false
;
options
.
force_output_use_user_specified_memory
=
false
;
LITE_WARN
(
"detect the unsupported dest operator %s when config "
"force_output_use_user_specified_memory, set "
"force_output_use_user_specified_memory to false
\n
"
,
opr
->
cname
());
break
;
}
}
}
}
void
NetworkImplDft
::
layout_transform_optimization
()
{
void
NetworkImplDft
::
layout_transform_optimization
()
{
if
(
m_set_layout_transform
)
{
if
(
m_set_layout_transform
)
{
mgb
::
ThinHashMap
<
mgb
::
SymbolVar
,
mgb
::
SymbolVar
>
out_var_map
;
mgb
::
ThinHashMap
<
mgb
::
SymbolVar
,
mgb
::
SymbolVar
>
out_var_map
;
...
@@ -611,10 +586,6 @@ void NetworkImplDft::configure_after_loaded() {
...
@@ -611,10 +586,6 @@ void NetworkImplDft::configure_after_loaded() {
layout_transform_optimization
();
layout_transform_optimization
();
//! some optimization option maybe invalid in some case, so here just
//! auto determine whether some options will apply.
adapt_option_valid
();
//! find how many compnode the model has, this should call before update_io
//! find how many compnode the model has, this should call before update_io
cross_compnode_model_detect
();
cross_compnode_model_detect
();
...
...
lite/src/mge/network_impl.h
浏览文件 @
896b0193
...
@@ -239,9 +239,6 @@ private:
...
@@ -239,9 +239,6 @@ private:
//! optimized output tensor copy
//! optimized output tensor copy
void
output_tensor_copy_optimize
(
Var
var
,
std
::
shared_ptr
<
Tensor
>
tensor
);
void
output_tensor_copy_optimize
(
Var
var
,
std
::
shared_ptr
<
Tensor
>
tensor
);
//! adapt option valid, it should call after update_io
void
adapt_option_valid
();
//! configure and optimize network after loaded
//! configure and optimize network after loaded
void
configure_after_loaded
();
void
configure_after_loaded
();
...
...
src/gopt/test/network.cpp
浏览文件 @
896b0193
#include "./network.h"
#include "./network.h"
#include "megbrain/opr/tensor_manip.h"
using
namespace
mgb
;
using
namespace
mgb
;
...
@@ -137,6 +138,35 @@ SymbolVar Network::add_concat(SymbolVar f, SymbolVar g, int axis) {
...
@@ -137,6 +138,35 @@ SymbolVar Network::add_concat(SymbolVar f, SymbolVar g, int axis) {
return
opr
::
Concat
::
make
({
f
,
g
},
axis
);
return
opr
::
Concat
::
make
({
f
,
g
},
axis
);
}
}
SymbolVar
Network
::
add_dimshuffle
(
SymbolVar
f
,
std
::
vector
<
int
>
pattern
)
{
return
opr
::
Dimshuffle
::
make
(
f
,
pattern
);
}
SymbolVar
Network
::
add_axisaddremove
(
SymbolVar
f
)
{
return
opr
::
AxisAddRemove
::
make
(
f
,
{{
opr
::
AxisAddRemove
::
AxisDesc
::
Method
::
REMOVE
,
{
0
}}});
}
SymbolVar
Network
::
add_subtensor
(
SymbolVar
f
)
{
using
AIdx
=
opr
::
indexing
::
AxisIndexer
;
return
opr
::
Subtensor
::
make
(
f
,
{
AIdx
::
make_interval
(
0
,
f
.
make_scalar
(
0
),
None
,
None
)});
}
SymbolVar
Network
::
add_reshape
(
SymbolVar
f
)
{
auto
shp
=
opr
::
GetVarShape
::
make
(
f
);
return
opr
::
Reshape
::
make
(
f
,
shp
);
}
SymbolVar
Network
::
add_broadcast
(
SymbolVar
f
)
{
auto
shp
=
opr
::
GetVarShape
::
make
(
f
);
return
opr
::
Broadcast
::
make
(
f
,
shp
);
}
SymbolVar
Network
::
add_copy
(
SymbolVar
f
)
{
return
opr
::
Copy
::
make
(
f
);
}
SymbolVar
mgb
::
create_block
(
SymbolVar
mgb
::
create_block
(
Network
&
network
,
SymbolVar
f_in
,
size_t
stride
,
size_t
num_outputs1
,
Network
&
network
,
SymbolVar
f_in
,
size_t
stride
,
size_t
num_outputs1
,
bool
has_proj
,
DType
out_dtype
)
{
bool
has_proj
,
DType
out_dtype
)
{
...
...
src/gopt/test/network.h
浏览文件 @
896b0193
...
@@ -53,6 +53,12 @@ public:
...
@@ -53,6 +53,12 @@ public:
opr
::
Pooling
::
Param
::
Mode
mode
=
opr
::
Pooling
::
Param
::
Mode
::
MAX
);
opr
::
Pooling
::
Param
::
Mode
mode
=
opr
::
Pooling
::
Param
::
Mode
::
MAX
);
SymbolVar
add_type_cvt
(
SymbolVar
f
,
DType
out_dtype
=
dtype
::
Float32
());
SymbolVar
add_type_cvt
(
SymbolVar
f
,
DType
out_dtype
=
dtype
::
Float32
());
SymbolVar
add_concat
(
SymbolVar
f
,
SymbolVar
g
,
int
axis
=
0
);
SymbolVar
add_concat
(
SymbolVar
f
,
SymbolVar
g
,
int
axis
=
0
);
SymbolVar
add_dimshuffle
(
SymbolVar
f
,
std
::
vector
<
int
>
pattern
);
SymbolVar
add_axisaddremove
(
SymbolVar
f
);
SymbolVar
add_subtensor
(
SymbolVar
f
);
SymbolVar
add_reshape
(
SymbolVar
f
);
SymbolVar
add_broadcast
(
SymbolVar
f
);
SymbolVar
add_copy
(
SymbolVar
f
);
};
};
SymbolVar
create_block
(
SymbolVar
create_block
(
...
...
src/gopt/test/no_memory_copy.cpp
浏览文件 @
896b0193
...
@@ -45,6 +45,35 @@ struct TestGraph {
...
@@ -45,6 +45,35 @@ struct TestGraph {
m_out_var
=
m_network
->
add_concat
(
f
,
-
f
);
m_out_var
=
m_network
->
add_concat
(
f
,
-
f
);
}
}
void
create_relayout_out_graph
(
int
mem_forward_opr_type
)
{
input_tensor
=
m_gen
({
1
,
3
,
32
,
32
},
m_cn
);
auto
input
=
opr
::
Host2DeviceCopy
::
make
(
*
m_network
->
graph
,
input_tensor
,
m_cn
)
.
rename
(
"input"
);
auto
f
=
m_network
->
add_conv
(
input
,
4
,
{
3
,
3
},
dtype
::
Float32
(),
true
,
{
2
,
2
},
{
0
,
0
});
f
=
m_network
->
add_elemwise
(
{
f
},
dtype
::
Float32
(),
opr
::
Elemwise
::
Param
::
Mode
::
EXP
);
f
=
m_network
->
add_conv
(
f
,
8
,
{
3
,
3
},
dtype
::
Float32
(),
true
,
{
1
,
1
},
{
1
,
1
});
f
=
m_network
->
add_pooling
(
f
,
{
2
,
2
},
{
2
,
2
});
//! dimshuffle
if
(
mem_forward_opr_type
==
0
)
{
f
=
m_network
->
add_dimshuffle
(
f
,
{
0
,
2
,
3
,
1
});
//! BroadCast
}
else
if
(
mem_forward_opr_type
==
1
)
{
f
=
m_network
->
add_broadcast
(
f
);
//! Subtensor
}
else
if
(
mem_forward_opr_type
==
2
)
{
f
=
m_network
->
add_subtensor
(
f
);
//! AxisAddRemove
}
else
if
(
mem_forward_opr_type
==
3
)
{
f
=
m_network
->
add_axisaddremove
(
f
);
//! Reshape
}
else
if
(
mem_forward_opr_type
==
4
)
{
f
=
m_network
->
add_reshape
(
f
);
}
m_out_var
=
m_network
->
add_copy
(
f
);
}
void
create_graph_with_subtensor_forward
()
{
void
create_graph_with_subtensor_forward
()
{
input_tensor
=
m_gen
({
2
,
3
,
32
,
32
},
m_cn
);
input_tensor
=
m_gen
({
2
,
3
,
32
,
32
},
m_cn
);
auto
input
=
opr
::
Host2DeviceCopy
::
make
(
*
m_network
->
graph
,
input_tensor
,
m_cn
)
auto
input
=
opr
::
Host2DeviceCopy
::
make
(
*
m_network
->
graph
,
input_tensor
,
m_cn
)
...
@@ -211,6 +240,67 @@ TEST(TestNoCopy, IONoCopyPtrEQ) {
...
@@ -211,6 +240,67 @@ TEST(TestNoCopy, IONoCopyPtrEQ) {
}
}
}
}
namespace
{
auto
test_memory_forward_io_no_copy
(
int
opr_type
,
TensorShape
shape
)
{
auto
test_graph
=
TestGraph
();
auto
compute_graph
=
test_graph
.
m_network
->
graph
;
compute_graph
->
options
().
force_output_use_user_specified_memory
=
true
;
test_graph
.
create_relayout_out_graph
(
opr_type
);
HostTensorND
truth
;
auto
func
=
test_graph
.
compile_without_copy
();
//! because the output tensor not assign user memory, so it will wrong
ASSERT_THROW
(
func
->
execute
(),
MegBrainError
);
auto
&&
outvar
=
func
->
get_output_vars
()[
0
];
ASSERT_EQ
(
outvar
,
test_graph
.
m_out_var
.
node
());
size_t
times
=
10
;
for
(
size_t
i
=
0
;
i
<
times
;
i
++
)
{
auto
input_tensor
=
test_graph
.
input_tensor
;
auto
layout
=
input_tensor
->
layout
();
size_t
length
=
layout
.
total_nr_elems
();
auto
storage
=
TensorStorage
<
HostTensorStorageTrait
>
(
test_graph
.
m_cn
);
storage
.
ensure_size
(
length
*
sizeof
(
float
));
float
*
ptr
=
storage
.
ptr
()
->
as
<
float
>
();
for
(
size_t
d
=
0
;
d
<
length
;
d
++
)
{
ptr
[
d
]
=
i
/
5
+
3
;
}
input_tensor
->
reset
(
storage
,
layout
);
DeviceTensorND
dv
(
test_graph
.
m_cn
,
shape
);
outvar
->
init_mem_plan
(
&
dv
);
outvar
->
reset_dev_tensor_from_tensor
(
dv
);
func
->
execute
();
func
->
wait
();
if
(
i
%
5
==
0
)
{
truth
.
copy_from
(
func
->
get_output_vars
()[
0
]
->
dev_tensor
()).
sync
();
continue
;
}
HostTensorND
to_check
;
to_check
.
copy_from
(
func
->
get_output_vars
()[
0
]
->
dev_tensor
()).
sync
();
MGB_ASSERT_TENSOR_EQ
(
to_check
,
truth
);
}
}
}
// namespace
TEST
(
TestNoCopy
,
IONoCopyEndWithDimshuffle
)
{
test_memory_forward_io_no_copy
(
0
,
{
1
,
7
,
7
,
8
});
}
TEST
(
TestNoCopy
,
IONoCopyEndWithReshape
)
{
test_memory_forward_io_no_copy
(
4
,
{
1
,
8
,
7
,
7
});
}
TEST
(
TestNoCopy
,
IONoCopyEndWithAxisAddRemove
)
{
test_memory_forward_io_no_copy
(
3
,
{
8
,
7
,
7
});
}
TEST
(
TestNoCopy
,
IONoCopyEndWithBroadCast
)
{
test_memory_forward_io_no_copy
(
1
,
{
1
,
8
,
7
,
7
});
}
TEST
(
TestNoCopy
,
IONoCopyEndWithSubtensor
)
{
test_memory_forward_io_no_copy
(
2
,
{
1
,
8
,
7
,
7
});
}
TEST
(
TestNoCopy
,
IONoCopyCorrect
)
{
TEST
(
TestNoCopy
,
IONoCopyCorrect
)
{
auto
test_graph
=
TestGraph
();
auto
test_graph
=
TestGraph
();
auto
compute_graph
=
test_graph
.
m_network
->
graph
;
auto
compute_graph
=
test_graph
.
m_network
->
graph
;
...
...
src/serialization/impl/serializer.cpp
浏览文件 @
896b0193
#include "megbrain/serialization/serializer.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/utility.h"
namespace
{
bool
is_opr_memforward_var
(
mgb
::
VarNode
*
var
)
{
if
(
var
)
{
auto
opr
=
var
->
owner_opr
();
if
(
opr
->
try_cast_final
<
mgb
::
opr
::
Reshape
>
()
||
opr
->
try_cast_final
<
mgb
::
opr
::
Broadcast
>
()
||
opr
->
try_cast_final
<
mgb
::
opr
::
Subtensor
>
()
||
opr
->
try_cast_final
<
mgb
::
opr
::
AxisAddRemove
>
()
||
opr
->
try_cast_final
<
mgb
::
opr
::
Dimshuffle
>
())
{
return
true
;
}
};
return
false
;
}
}
// namespace
namespace
mgb
{
namespace
mgb
{
namespace
serialization
{
namespace
serialization
{
...
@@ -42,6 +60,14 @@ void GraphLoader::LoadResult::graph_compile_ahead() {
...
@@ -42,6 +60,14 @@ void GraphLoader::LoadResult::graph_compile_ahead() {
//! just do basic optimize_for_inference ahead, and replace the var in
//! just do basic optimize_for_inference ahead, and replace the var in
//! LoadResult
//! LoadResult
if
(
graph
->
options
().
force_output_use_user_specified_memory
)
{
if
(
graph
->
options
().
force_output_use_user_specified_memory
)
{
//! if the output var is like dimshuffle, reshape, it maybe memory forward to
//! the output, so add a Copy operator in the end.
for
(
auto
&
var
:
output_var_list
)
{
if
(
is_opr_memforward_var
(
var
.
node
()))
{
std
::
string
name
=
var
.
node
()
->
name
();
var
=
opr
::
Copy
::
make
(
var
,
name
);
}
}
auto
options
=
gopt
::
OptimizeForInferenceOptions
{};
auto
options
=
gopt
::
OptimizeForInferenceOptions
{};
auto
new_vars
=
gopt
::
optimize_for_inference
(
output_var_list
,
options
);
auto
new_vars
=
gopt
::
optimize_for_inference
(
output_var_list
,
options
);
output_var_list
=
new_vars
;
output_var_list
=
new_vars
;
...
@@ -62,6 +88,7 @@ void GraphLoader::LoadResult::graph_compile_ahead() {
...
@@ -62,6 +88,7 @@ void GraphLoader::LoadResult::graph_compile_ahead() {
found
,
"can't find var name %s when optimize_for_inference. "
,
found
,
"can't find var name %s when optimize_for_inference. "
,
var
.
node
()
->
cname
());
var
.
node
()
->
cname
());
}
}
output_var_map_id
=
var_map_id
;
}
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录