Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
62b5452d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
62b5452d
编写于
8月 24, 2022
作者:
W
Wilber
提交者:
GitHub
8月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
conv_eltwiseadd_bn_fuse support fp16 (#45379)
上级
3c14b094
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
39 addition
and
47 deletion
+39
-47
paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
+39
-47
未找到文件。
paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
浏览文件 @
62b5452d
...
...
@@ -17,8 +17,12 @@
#include <string>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/common/data_type.h"
namespace
phi
{
class
DenseTensor
;
...
...
@@ -30,6 +34,23 @@ class Scope;
}
// namespace framework
}
// namespace paddle
namespace
{
template
<
typename
T1
,
typename
T2
>
void
ConvertTensorType
(
paddle
::
framework
::
LoDTensor
*
tensor
)
{
paddle
::
framework
::
Tensor
tmp_tensor
;
tmp_tensor
.
set_type
(
paddle
::
experimental
::
CppTypeToDataType
<
T2
>::
Type
());
tmp_tensor
.
Resize
(
tensor
->
dims
());
auto
*
tmp_data
=
tmp_tensor
.
mutable_data
<
T2
>
(
paddle
::
platform
::
CPUPlace
());
auto
*
data
=
tensor
->
mutable_data
<
T1
>
(
paddle
::
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
tensor
->
numel
();
i
++
)
{
tmp_data
[
i
]
=
static_cast
<
T2
>
(
data
[
i
]);
}
tensor
->
clear
();
paddle
::
framework
::
TensorCopySync
(
tmp_tensor
,
paddle
::
platform
::
CPUPlace
(),
tensor
);
}
}
// namespace
namespace
paddle
{
namespace
framework
{
namespace
ir
{
...
...
@@ -290,19 +311,7 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
auto
tensor_type
=
conv_weight_tensor
->
dtype
();
if
(
tensor_type
==
paddle
::
experimental
::
DataType
::
FLOAT16
)
{
framework
::
Tensor
weight_float_tensor
;
weight_float_tensor
.
set_type
(
paddle
::
experimental
::
DataType
::
FLOAT32
);
weight_float_tensor
.
Resize
(
conv_weight_tensor
->
dims
());
auto
*
weight_float_data
=
weight_float_tensor
.
mutable_data
<
float
>
(
platform
::
CPUPlace
());
auto
*
data
=
conv_weight_tensor
->
mutable_data
<
float16
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
conv_weight_tensor
->
numel
();
i
++
)
{
weight_float_data
[
i
]
=
static_cast
<
float
>
(
data
[
i
]);
}
conv_weight_tensor
->
clear
();
paddle
::
framework
::
TensorCopySync
(
weight_float_tensor
,
platform
::
CPUPlace
(),
conv_weight_tensor
);
ConvertTensorType
<
float16
,
float
>
(
conv_weight_tensor
);
}
// Get batch norm bias
...
...
@@ -341,40 +350,8 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
conv_type
());
if
(
tensor_type
==
paddle
::
experimental
::
DataType
::
FLOAT16
)
{
{
framework
::
Tensor
weight_float16_tensor
;
weight_float16_tensor
.
set_type
(
paddle
::
experimental
::
DataType
::
FLOAT16
);
weight_float16_tensor
.
Resize
(
conv_weight_tensor
->
dims
());
auto
*
weight_float16_data
=
weight_float16_tensor
.
mutable_data
<
float16
>
(
platform
::
CPUPlace
());
auto
*
data
=
conv_weight_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
conv_weight_tensor
->
numel
();
i
++
)
{
weight_float16_data
[
i
]
=
static_cast
<
float16
>
(
data
[
i
]);
}
conv_weight_tensor
->
clear
();
paddle
::
framework
::
TensorCopySync
(
weight_float16_tensor
,
platform
::
CPUPlace
(),
conv_weight_tensor
);
}
{
framework
::
Tensor
eltwise_y_in_float16_tensor
;
eltwise_y_in_float16_tensor
.
set_type
(
paddle
::
experimental
::
DataType
::
FLOAT16
);
eltwise_y_in_float16_tensor
.
Resize
(
eltwise_y_in_tensor
->
dims
());
auto
*
eltwise_y_in_float16_data
=
eltwise_y_in_float16_tensor
.
mutable_data
<
float16
>
(
platform
::
CPUPlace
());
auto
*
data
=
eltwise_y_in_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
eltwise_y_in_tensor
->
numel
();
i
++
)
{
eltwise_y_in_float16_data
[
i
]
=
static_cast
<
float16
>
(
data
[
i
]);
}
eltwise_y_in_tensor
->
clear
();
paddle
::
framework
::
TensorCopySync
(
eltwise_y_in_float16_tensor
,
platform
::
CPUPlace
(),
eltwise_y_in_tensor
);
}
ConvertTensorType
<
float
,
float16
>
(
conv_weight_tensor
);
ConvertTensorType
<
float
,
float16
>
(
eltwise_y_in_tensor
);
}
// with MKL-DNN fuse conv+bn into conv with bias
...
...
@@ -612,6 +589,16 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
float
epsilon
=
PADDLE_GET_CONST
(
float
,
batch_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
// conv_weight fp16 --> fp32
auto
*
conv_weight_tensor
=
scope
->
FindVar
(
conv_weight
->
Name
())
->
GetMutable
<
LoDTensor
>
();
auto
tensor_type
=
conv_weight_tensor
->
dtype
();
if
(
tensor_type
==
paddle
::
experimental
::
DataType
::
FLOAT16
)
{
ConvertTensorType
<
float16
,
float
>
(
conv_weight_tensor
);
ConvertTensorType
<
float16
,
float
>
(
eltwise_y_in_tensor
);
}
// if bias is an input to other ops as well then we cannot overwrite it
// so we create separate elementwise Y in nodes
if
(
eltwise_y_in
->
outputs
.
size
()
>
1
)
{
...
...
@@ -666,6 +653,11 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
conv_type
());
}
if
(
tensor_type
==
paddle
::
experimental
::
DataType
::
FLOAT16
)
{
ConvertTensorType
<
float
,
float16
>
(
conv_weight_tensor
);
ConvertTensorType
<
float
,
float16
>
(
eltwise_y_in_tensor
);
}
// Update the elementwise_add node
eltwise
->
Op
()
->
SetAttr
(
"axis"
,
1
);
eltwise
->
Op
()
->
SetOutput
(
"Out"
,
std
::
vector
<
std
::
string
>
({
bn_out
->
Name
()}));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录