Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
5ebc9d50
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
5ebc9d50
编写于
3月 07, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(pylite): fix lite global layout transform and fast run conflict error
GitOrigin-RevId: 910c8da19f3c9973a088b70b782e91f6b366d4f9
上级
49d92d9c
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
76 addition
and
23 deletion
+76
-23
lite/load_and_run/src/models/model_lite.h
lite/load_and_run/src/models/model_lite.h
+0
-1
lite/pylite/test/test_network.py
lite/pylite/test/test_network.py
+26
-0
lite/pylite/test/test_network_device.py
lite/pylite/test/test_network_device.py
+28
-0
lite/src/mge/network_impl.cpp
lite/src/mge/network_impl.cpp
+7
-7
src/gopt/impl/global_layout_transform/opr_format_modifier.cpp
...gopt/impl/global_layout_transform/opr_format_modifier.cpp
+15
-15
未找到文件。
lite/load_and_run/src/models/model_lite.h
浏览文件 @
5ebc9d50
...
...
@@ -40,7 +40,6 @@ public:
void
wait
()
override
;
//! enable global layout transform
void
set_layout_transform
(
bool
state
)
{
enable_layout_transform
=
state
;
}
//! get the network of lite model
...
...
lite/pylite/test/test_network.py
浏览文件 @
5ebc9d50
...
...
@@ -468,3 +468,29 @@ class TestNetwork(TestShuffleNet):
fi
=
open
(
"./model_afer_layoutTrans.mgb"
,
"r"
)
fi
.
close
()
os
.
remove
(
"./model_afer_layoutTrans.mgb"
)
def
test_fast_run_and_global_layout_transform
(
self
):
config_
=
LiteConfig
()
network
=
LiteNetwork
(
config_
)
fast_run_cache
=
"./algo_cache"
global_layout_transform_model
=
"./model_afer_layoutTrans.mgb"
network
.
set_network_algo_policy
(
LiteAlgoSelectStrategy
.
LITE_ALGO_PROFILE
|
LiteAlgoSelectStrategy
.
LITE_ALGO_OPTIMIZED
)
network
.
enable_global_layout_transform
()
network
.
load
(
self
.
model_path
)
self
.
do_forward
(
network
)
network
.
dump_layout_transform_model
(
global_layout_transform_model
)
LiteGlobal
.
dump_persistent_cache
(
fast_run_cache
)
fi
=
open
(
fast_run_cache
,
"r"
)
fi
.
close
()
fi
=
open
(
global_layout_transform_model
,
"r"
)
fi
.
close
()
LiteGlobal
.
set_persistent_cache
(
path
=
fast_run_cache
)
self
.
do_forward
(
network
)
os
.
remove
(
fast_run_cache
)
os
.
remove
(
global_layout_transform_model
)
lite/pylite/test/test_network_device.py
浏览文件 @
5ebc9d50
...
...
@@ -293,3 +293,31 @@ class TestNetwork(TestShuffleNetCuda):
fi
=
open
(
"./model_afer_layoutTrans.mgb"
,
"r"
)
fi
.
close
()
os
.
remove
(
"./model_afer_layoutTrans.mgb"
)
@
require_cuda
()
def
test_fast_run_and_global_layout_transform
(
self
):
config_
=
LiteConfig
()
config_
.
device_type
=
LiteDeviceType
.
LITE_CUDA
network
=
LiteNetwork
(
config_
)
fast_run_cache
=
"./algo_cache"
global_layout_transform_model
=
"./model_afer_layoutTrans.mgb"
network
.
set_network_algo_policy
(
LiteAlgoSelectStrategy
.
LITE_ALGO_PROFILE
|
LiteAlgoSelectStrategy
.
LITE_ALGO_OPTIMIZED
)
network
.
enable_global_layout_transform
()
network
.
load
(
self
.
model_path
)
self
.
do_forward
(
network
)
network
.
dump_layout_transform_model
(
global_layout_transform_model
)
LiteGlobal
.
dump_persistent_cache
(
fast_run_cache
)
fi
=
open
(
fast_run_cache
,
"r"
)
fi
.
close
()
fi
=
open
(
global_layout_transform_model
,
"r"
)
fi
.
close
()
LiteGlobal
.
set_persistent_cache
(
path
=
fast_run_cache
)
self
.
do_forward
(
network
)
os
.
remove
(
fast_run_cache
)
os
.
remove
(
global_layout_transform_model
)
lite/src/mge/network_impl.cpp
浏览文件 @
5ebc9d50
...
...
@@ -422,6 +422,8 @@ void NetworkImplDft::load_model(
m_load_result
=
m_loader
->
load
(
m_load_config
,
true
);
modify_exection_policy
();
global_layout_transform
();
adapt_option_valid
();
...
...
@@ -436,7 +438,6 @@ void NetworkImplDft::load_model(
}
void
NetworkImplDft
::
compile_graph
()
{
modify_exection_policy
();
replace_dev_input_pass
();
make_output_spec
();
m_execute_func
=
m_load_result
.
graph_compile
(
m_output_spec
);
...
...
@@ -793,7 +794,8 @@ void NetworkImplDft::set_network_algo_policy(
if
(
static_cast
<
uint32_t
>
(
strategy
)
&
LiteAlgoSelectStrategy
::
LITE_ALGO_OPTIMIZED
)
{
dst_strategy
=
dst_strategy
|
S
::
OPTIMIZED
;
}
m_execution_policy
=
dst_strategy
;
if
(
static_cast
<
uint32_t
>
(
dst_strategy
)
!=
0
)
m_execution_policy
=
dst_strategy
;
auto
&&
fast_run_config
=
m_load_config
.
comp_graph
->
options
().
fast_run_config
;
fast_run_config
.
binary_equal_between_batch
=
binary_equal_between_batch
;
...
...
@@ -808,12 +810,10 @@ void NetworkImplDft::set_network_algo_policy(
}
void
NetworkImplDft
::
modify_exection_policy
()
{
mgb
::
SymbolVarArray
vars
;
for
(
auto
i
:
m_output_spec
)
{
vars
.
push_back
(
i
.
first
);
}
if
(
static_cast
<
uint32_t
>
(
m_execution_policy
)
!=
0
)
auto
&
vars
=
m_load_result
.
output_var_list
;
if
(
static_cast
<
uint32_t
>
(
m_execution_policy
)
!=
0
)
{
mgb
::
gopt
::
modify_opr_algo_strategy_inplace
(
vars
,
m_execution_policy
);
}
}
//! set opr algorithm selection strategy in the network
...
...
src/gopt/impl/global_layout_transform/opr_format_modifier.cpp
浏览文件 @
5ebc9d50
...
...
@@ -289,21 +289,21 @@ namespace intl {
template
<
typename
Opr
>
struct
OprFormatModifier
;
#define INST(_Opr) \
template <> \
struct OprFormatModifier<_Opr> { \
using OprFormat = typename _Opr::Param::Format; \
static VarNode* make( \
OprFormat opr_format, const VarNodeArray& i, \
const cg::OperatorNodeBase* opr_) { \
MIDOUT_B(_Opr) \
auto&& opr = opr_->cast_final_safe<_Opr>(); \
auto param = opr.param(); \
param.format = opr_format; \
return OprWithPolicyMaker<_Opr>::make( \
i, param, opr.execution_policy(), opr.config()); \
MIDOUT_E \
} \
#define INST(_Opr)
\
template <>
\
struct OprFormatModifier<_Opr> {
\
using OprFormat = typename _Opr::Param::Format;
\
static VarNode* make(
\
OprFormat opr_format, const VarNodeArray& i,
\
const cg::OperatorNodeBase* opr_) {
\
MIDOUT_B(_Opr)
\
auto&& opr = opr_->cast_final_safe<_Opr>();
\
auto param = opr.param();
\
param.format = opr_format;
\
return OprWithPolicyMaker<_Opr>::make(
\
i, param, opr.execution_policy
_transient
(), opr.config()); \
MIDOUT_E
\
}
\
};
INST
(
Convolution
);
INST
(
ConvBiasForward
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录