MegEngine 天元 / MegEngine
Commit 3e0bb22c
Authored Dec 27, 2022 by Megvii Engine Team
feat(lite): load and run supports convert fp32 to fp16 online
GitOrigin-RevId: 05c9a17a00301ced15043cf34284f2d362d2608c
Parent: f71dd489
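With this change, load_and_run can ask the runtime to convert an fp32 model to fp16 I/O and compute at load time through the new --enable-ioc16 flag (defined in dtype_options.cpp below). A minimal invocation sketch; the model path is a placeholder:

    load_and_run ./model.mge --enable-ioc16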
Showing 7 changed files with 230 additions and 0 deletions.
lite/include/lite/network.h (+1, -0)
lite/load_and_run/src/helpers/common.h (+4, -0)
lite/load_and_run/src/options/dtype_options.cpp (+106, -0)
lite/load_and_run/src/options/dtype_options.h (+48, -0)
lite/src/mge/network_impl.cpp (+3, -0)
src/gopt/impl/inference.cpp (+64, -0)
src/gopt/include/megbrain/gopt/inference.h (+4, -0)
lite/include/lite/network.h

@@ -97,6 +97,7 @@ struct LITE_API Options {
     bool enable_nchw4 = false;
     bool enable_nchw32 = false;
     bool enable_nchw64 = false;
+    bool enable_f16_io_comp = false;  // convert to fp16
 };

 /**
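For C++ Lite users, the new Options::enable_f16_io_comp field can be flipped on the config before the network loads the model. A minimal sketch, assuming the usual lite::Network-from-Config flow; the model path is a placeholder:

#include <memory>
#include "lite/network.h"

int main() {
    lite::Config config;
    // ask the runtime to convert fp32 I/O and compute to fp16 at load time
    config.options.enable_f16_io_comp = true;
    auto network = std::make_shared<lite::Network>(config);
    network->load_model("./model.mge");  // placeholder path
    return 0;
}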
lite/load_and_run/src/helpers/common.h

@@ -67,6 +67,10 @@ enum class OptLayoutType {
     NHWCD4 = 1 << 6,
     NCHW44_DOT = 1 << 7
 };
+/*!
+ * \brief: dtype type for running model optimization
+ */
+enum class OptDTypeType { IOC16 = 1 << 0 };
 /**
  * base class to store option value
  */
lite/load_and_run/src/options/dtype_options.cpp (new file, mode 100644)

#include <gflags/gflags.h>
#include "misc.h"
#include "models/model_lite.h"
#include "models/model_mdl.h"

#include "dtype_options.h"

namespace lar {

template <>
void DTypeOption::config_model_internel<ModelLite>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
#define ENABLE_DTYPE(dtype)                            \
    LITE_LOG("enable " #dtype " optimization");        \
    model->get_config().options.enable_##dtype = true; \
    break;

        switch (m_option_flag) {
            case OptDTypeType::IOC16:
                ENABLE_DTYPE(f16_io_comp)
            default:
                LITE_THROW(
                        "Set unsupport dtype, only --enable-ioc16 is supported. "
                        "Default case is fp32.");
                break;
        }
#undef ENABLE_DTYPE
    }
}

template <>
void DTypeOption::config_model_internel<ModelMdl>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
#define ENABLE_DTYPE(dtype)                                                   \
    mgb_log("enable " #dtype " optimization");                                \
    model->get_mdl_config().comp_graph->options().graph_opt.enable_##dtype(); \
    break;

        switch (m_option_flag) {
            case OptDTypeType::IOC16:
                ENABLE_DTYPE(f16_io_comp)
            default:
                LITE_THROW(
                        "Set unsupport dtype, only --enable-ioc16 is supported. "
                        "Default case is fp32.");
                break;
        }
#undef ENABLE_DTYPE
    }
}
}  // namespace lar

using namespace lar;

bool DTypeOption::m_valid;

void DTypeOption::update() {
    m_option_name = "dtype";
    m_option_flag = static_cast<OptDTypeType>(0);
    m_option = {
            {"enable_ioc16", lar::Bool::make(false)},
    };
    std::static_pointer_cast<lar::Bool>(m_option["enable_ioc16"])
            ->set_value(FLAGS_enable_ioc16);
}

bool DTypeOption::is_valid() {
    size_t valid_flag = 0;
    if (FLAGS_enable_ioc16) {
        valid_flag |= static_cast<size_t>(OptDTypeType::IOC16);
    }
    //! only one flag is valid
    bool ret = valid_flag && !(valid_flag & (valid_flag - 1));
    return ret | m_valid;
};

std::shared_ptr<OptionBase> DTypeOption::create_option() {
    static std::shared_ptr<DTypeOption> option(new DTypeOption);
    if (DTypeOption::is_valid()) {
        option->update();
        return std::static_pointer_cast<OptionBase>(option);
    } else {
        return nullptr;
    }
}

void DTypeOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    size_t valid_flag = 0;
    if (FLAGS_enable_ioc16 ||
        std::static_pointer_cast<lar::Bool>(m_option["enable_ioc16"])->get_value()) {
        valid_flag |= static_cast<size_t>(OptDTypeType::IOC16);
    }
    mgb_throw_if(
            valid_flag && (valid_flag & (valid_flag - 1)), mgb::AssertionError,
            "invalid options of dtype transform 0x%lx", valid_flag);
    m_option_flag = static_cast<OptDTypeType>(valid_flag);
    CONFIG_MODEL_FUN;
}

DEFINE_bool(enable_ioc16, false, "enable fp16 dtype optimization!!");

REGIST_OPTION_CREATOR(dtype, lar::DTypeOption::create_option);
REGIST_OPTION_VALIDATER(dtype, lar::DTypeOption::set_valid);
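Both is_valid() and config_model() use the classic power-of-two test, valid_flag & (valid_flag - 1), to enforce that at most one dtype flag is selected: subtracting one clears the lowest set bit, so the AND is zero exactly when a single bit (or none) is set. A self-contained illustration with hypothetical flag values:

#include <cassert>
#include <cstddef>

int main() {
    size_t one_flag = 1 << 0;                // e.g. OptDTypeType::IOC16 alone
    size_t two_flags = (1 << 0) | (1 << 1);  // two dtype options at once
    assert((one_flag & (one_flag - 1)) == 0);    // single flag: accepted
    assert((two_flags & (two_flags - 1)) != 0);  // multiple flags: rejected
    return 0;
}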
lite/load_and_run/src/options/dtype_options.h (new file, mode 100644)

#pragma once

#include <gflags/gflags.h>
#include "helpers/common.h"
#include "models/model.h"
#include "option_base.h"

DECLARE_bool(enable_ioc16);

namespace lar {
/*!
 * \brief: dtype option for optimization
 */
class DTypeOption final : public OptionBase {
public:
    //! check the validation of option flag
    static bool is_valid();

    //! create options when option is used
    static std::shared_ptr<OptionBase> create_option();

    //! config the model, dispatch configuration for different model implement
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;

    //! get option name
    std::string option_name() const override { return m_option_name; };

    static void set_valid(bool val) { m_valid = val; }

    OptionValMap* get_option() override { return &m_option; }

    void update() override;

private:
    //! Constructor
    DTypeOption() = default;

    //! configuration for different model implement
    template <typename ModelImpl>
    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};

    OptDTypeType m_option_flag;
    std::string m_option_name;
    static bool m_valid;
    OptionValMap m_option;
};
}  // namespace lar
lite/src/mge/network_impl.cpp

@@ -97,6 +97,9 @@ void NetworkImplDft::application_config() {
     ConfigOptionLayoutTransform(enable_nchw32);
     ConfigOptionLayoutTransform(enable_nchw64);
 #undef ConfigOptionLayoutTransform
+    if (m_user_config->options.enable_f16_io_comp) {
+        options.graph_opt.enable_f16_io_comp();
+    }
     if (m_user_config->has_compression) {
         m_load_config.tensor_value_loader = decompressed_tensor_value_loader;
     }
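This wires the Lite-level option to the same graph_opt switch that MegBrain exposes directly on a computing graph. For users building graphs by hand, the equivalent toggle would be roughly the following sketch, assuming a graph created via mgb::ComputingGraph::make():

auto graph = mgb::ComputingGraph::make();
// turn on the fp32 -> fp16 conversion pass for I/O and compute
graph->options().graph_opt.enable_f16_io_comp();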
src/gopt/impl/inference.cpp

@@ -641,6 +641,22 @@ void ConvertF32ToF16Pass::apply(OptState& state) const {
         for (size_t i = 0; i < origin_out.size(); i++) {
             rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
         }
+    } else if (
+            m_multi_tensor_replace_func.find(opr->dyn_typeinfo()) !=
+            m_multi_tensor_replace_func.end()) {
+        auto&& new_inp = new_inp_cache;
+        new_inp.clear();
+        new_inp.reserve(opr->input().size());
+        for (auto i : opr->input()) {
+            new_inp.push_back(rewriter.get_var(i));
+        }
+        auto&& origin_out = opr->output();
+        auto&& cur_out =
+                m_multi_tensor_replace_func.at(opr->dyn_typeinfo())(opr, new_inp);
+        mgb_assert(origin_out.size() == cur_out.size());
+        for (size_t i = 0; i < origin_out.size(); i++) {
+            rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
+        }
     } else {
         rewriter.auto_replace_outputs(opr);
     }

@@ -691,6 +707,48 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(bool use_f32_comp
         return opr;
     };
+    auto replace_multi_sdt_opr = [](OperatorNodeBase* opr,
+                                    const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        auto& multi_sdt_opr =
+                opr->cast_final_safe<opr::MultipleDeviceTensorHolder>();
+        VarNodeArray cvt_vars;
+        cvt_vars.reserve(multi_sdt_opr.output().size());
+        for (size_t i = 0; i < multi_sdt_opr.output().size(); ++i) {
+            if (multi_sdt_opr.output(i)->dtype() == dtype::Float32()) {
+                cvt_vars.append(
+                        {opr::TypeCvt::make(
+                                 multi_sdt_opr.output(i), dtype::Float16(), {})
+                                 .node()});
+            } else {
+                cvt_vars.append({multi_sdt_opr.output(i)});
+            }
+        }
+        return cvt_vars;
+    };
+
+    auto replace_multi_sdt_with_format_opr = [](OperatorNodeBase* opr,
+                                                const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        auto& multi_sdt_with_format_opr =
+                opr->cast_final_safe<opr::MultipleDeviceTensorWithFormatHolder>();
+        VarNodeArray cvt_vars;
+        cvt_vars.reserve(multi_sdt_with_format_opr.output().size());
+        for (size_t i = 0; i < multi_sdt_with_format_opr.output().size(); ++i) {
+            if (multi_sdt_with_format_opr.output(i)->dtype() == dtype::Float32()) {
+                cvt_vars.append(
+                        {opr::TypeCvt::make(
+                                 multi_sdt_with_format_opr.output(i),
+                                 dtype::Float16(), {})
+                                 .node()});
+            } else {
+                cvt_vars.append({multi_sdt_with_format_opr.output(i)});
+            }
+        }
+        return cvt_vars;
+    };
     auto replace_imt_opr = [](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
         mgb_assert(opr->same_type<opr::ImmutableTensor>());
         mgb_assert(opr->input().size() == new_inp.size());

@@ -934,6 +992,12 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(bool use_f32_comp
     replace_func[opr::WarpPerspective::typeinfo()] = replace_warp_opr;
     replace_func[opr::Remap::typeinfo()] = replace_remap_opr;
     replace_func[opr::BatchedMatrixMul::typeinfo()] = replace_batched_matmul_opr;
+    auto& tensor_replace_func = ret->m_multi_tensor_replace_func;
+    tensor_replace_func[opr::MultipleDeviceTensorHolder::typeinfo()] =
+            replace_multi_sdt_opr;
+    tensor_replace_func[opr::MultipleDeviceTensorWithFormatHolder::typeinfo()] =
+            replace_multi_sdt_with_format_opr;
     return ret;
 #endif
 }
src/gopt/include/megbrain/gopt/inference.h

@@ -73,6 +73,10 @@ class ConvertF32ToF16Pass : public Pass {
             Typeinfo*,
             thin_function<OperatorNodeBase*(OperatorNodeBase*, const VarNodeArray&)>>
             m_opr_replace_func;
+    ThinHashMap<
+            Typeinfo*,
+            thin_function<VarNodeArray(OperatorNodeBase*, const VarNodeArray&)>>
+            m_multi_tensor_replace_func;
     VarReplaceCheckFlag m_var_replace_check_flag = VarReplaceCheckFlag::CHECK_ALL;

 public:
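The new m_multi_tensor_replace_func mirrors m_opr_replace_func, but its callbacks return a whole VarNodeArray because MultipleDeviceTensorHolder-style operators produce several outputs that must be rewritten together. The underlying pattern, a per-type dispatch table of rewrite callbacks, can be sketched in standalone C++ (all names here are illustrative, not MegEngine's):

#include <functional>
#include <typeindex>
#include <unordered_map>
#include <vector>

struct Node {
    virtual ~Node() = default;  // polymorphic so typeid sees the dynamic type
};
using NodeArray = std::vector<Node*>;
using ReplaceFn = std::function<NodeArray(Node*, const NodeArray&)>;

// map a node's concrete type to a rewrite callback returning all new outputs
std::unordered_map<std::type_index, ReplaceFn> multi_output_replace;

NodeArray rewrite(Node* opr, const NodeArray& inputs) {
    auto it = multi_output_replace.find(typeid(*opr));
    if (it != multi_output_replace.end())
        return it->second(opr, inputs);  // type-specific multi-output rewrite
    return inputs;                       // no entry: leave outputs unchanged
}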