Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
fa10524d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
fa10524d
编写于
11月 29, 2022
作者:
S
Sławomir Siwek
提交者:
GitHub
11月 29, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
eltwise_div + scale [PHI] (#48484)
上级
9e9b705a
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
43 addition
and
0 deletion
+43
-0
paddle/fluid/operators/ops_extra_info.h
paddle/fluid/operators/ops_extra_info.h
+1
-0
paddle/phi/kernels/onednn/elementwise_kernel.cc
paddle/phi/kernels/onednn/elementwise_kernel.cc
+6
-0
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_elt_act_fuse_pass.py
...s/unittests/ir/inference/test_mkldnn_elt_act_fuse_pass.py
+36
-0
未找到文件。
paddle/fluid/operators/ops_extra_info.h
浏览文件 @
fa10524d
...
...
@@ -95,6 +95,7 @@ const std::unordered_map<std::string, ExtraAttrPropertySet>
{
"fuse_activation_alpha"
,
ExtraAttrProperty
::
ONEDNN
},
{
"fuse_activation_beta"
,
ExtraAttrProperty
::
ONEDNN
},
{
"fuse_activation_scale"
,
ExtraAttrProperty
::
ONEDNN
},
{
"fused_output_scale"
,
ExtraAttrProperty
::
ONEDNN
},
{
"fuse_alpha"
,
ExtraAttrProperty
::
ONEDNN
},
{
"fuse_beta"
,
ExtraAttrProperty
::
ONEDNN
},
{
"fuse_relu"
,
ExtraAttrProperty
::
ONEDNN
},
...
...
paddle/phi/kernels/onednn/elementwise_kernel.cc
浏览文件 @
fa10524d
...
...
@@ -43,6 +43,12 @@ void ElementwiseKernel(const OneDNNContext& dev_ctx,
dnnl
::
post_ops
post_operations
;
funcs
::
AppendActivation
(
dev_ctx
,
post_operations
);
if
(
dev_ctx
.
HasDnnAttr
(
"fused_output_scale"
))
{
float
scale_alpha
=
PADDLE_GET_CONST
(
float
,
dev_ctx
.
GetDnnAttr
(
"fused_output_scale"
));
post_operations
.
append_eltwise
(
1.0
,
dnnl
::
algorithm
::
eltwise_linear
,
scale_alpha
,
0.0
f
);
}
auto
*
non_const_x
=
&
x
;
auto
*
non_const_y
=
&
y
;
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_elt_act_fuse_pass.py
浏览文件 @
fa10524d
...
...
@@ -356,6 +356,42 @@ class ElementwiseActivationMkldnnFusePassTest_Mul_Sigmoid(
self
.
act
=
paddle
.
nn
.
functional
.
sigmoid
class
ElementwiseScaleOneDNNFusePassTest_Add
(
ElementwiseActivationMkldnnFusePassTest
):
def
set_params
(
self
):
self
.
operand
=
fluid
.
layers
.
elementwise_add
self
.
act_alpha
=
0.6
self
.
act
=
paddle
.
scale
class
ElementwiseScaleOneDNNFusePassTest_Sub
(
ElementwiseActivationMkldnnFusePassTest
):
def
set_params
(
self
):
self
.
operand
=
fluid
.
layers
.
elementwise_sub
self
.
act_alpha
=
0.6
self
.
act
=
paddle
.
scale
class
ElementwiseScaleOneDNNFusePassTest_Mul
(
ElementwiseActivationMkldnnFusePassTest
):
def
set_params
(
self
):
self
.
operand
=
fluid
.
layers
.
elementwise_mul
self
.
act_alpha
=
0.6
self
.
act
=
paddle
.
scale
class
ElementwiseScaleOneDNNFusePassTest_Div
(
ElementwiseActivationMkldnnFusePassTest
):
def
set_params
(
self
):
self
.
operand
=
fluid
.
layers
.
elementwise_div
self
.
act_alpha
=
0.6
self
.
act
=
paddle
.
scale
if
__name__
==
"__main__"
:
paddle
.
enable_static
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录