Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
9ce343f8
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9ce343f8
编写于
9月 10, 2018
作者:
T
Tomasz Patejko
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MKLDNN conv + elementwise_add fusion: initial implementation of patterns
上级
da722d6d
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
198 addition
and
0 deletion
+198
-0
paddle/fluid/framework/ir/mkldnn_conv_elementwise_add_fuse_pass.cc
...uid/framework/ir/mkldnn_conv_elementwise_add_fuse_pass.cc
+174
-0
paddle/fluid/framework/ir/mkldnn_conv_elementwise_add_fuse_pass.h
...luid/framework/ir/mkldnn_conv_elementwise_add_fuse_pass.h
+24
-0
未找到文件。
paddle/fluid/framework/ir/mkldnn_conv_elementwise_add_fuse_pass.cc
0 → 100644
浏览文件 @
9ce343f8
#include "paddle/fluid/framework/ir/mkldnn_conv_elementwise_add_fuse_pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
struct
PatternNode
{
PatternNode
(
PDPattern
*
pattern
,
const
std
::
string
&
name
,
const
std
::
string
&
name_scope
,
const
std
::
string
&
repr
,
size_t
id
)
:
nodeName
{
PDNodeName
(
name_scope
,
repr
,
id
,
name
)}
,
node
{
pattern
->
RetrieveNode
(
nodeName
)
{
}
std
::
string
name
()
{
return
nodeName
};
PDNode
*
node
()
{
return
node
};
private:
std
::
string
nodeName
;
PDNode
*
node
;
};
/*
struct Conv : public PatternBase {
Conv(PDPattern* pattern, const std::string& name_scope)
: PatternBase{pattern, name_scope, "conv"}
, conv{pattern, "conv", name_scope_, repr_, id_}
, input{pattern, "Input", name_scope_, repr_, id_}
, filter{pattern, "Filter", name_scope_, repr_, id_}
, output{pattern, "Output", node_scope_, repr_ id_}
{ }
private:
PatternNode conv;
PatternNode input;
PatternNode filter;
PatternNode output;
public:
PDNode* operator()() {
auto conv_op = pattern->NewNode(conv.name())
->assert_is_op("conv2d");
auto input_var = pattern->NewNode(input.name())
->AsInput()
->assert_is_op_input(conv.name());
auto filter_var = pattern->NewNode(filter.name())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input(conv.name());
auto output_var = patterh->NewNode(output.name())
->AsOutput()
->assert_is_op_output(conv.name());
conv_op->LinksFrom({input_var, filter_var});
conv_op->LinksTo({output_var};
return output_var;
}
};
*/
struct
Conv
:
public
PatternBase
{
Conv
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
{
pattern
,
name_scope
,
"conv"
}
{
}
std
::
string
conv_name
()
{
return
PDNodeName
(
name_scope_
,
repr_
,
id_
,
"conv2d"
);
}
PDNode
*
conv_node
()
{
return
pattern
->
RetrieveNode
(
conv_name
());
}
std
::
string
input_name
()
{
return
PDNodeName
(
name_scope
,
repr_
,
id_
,
"Input"
);
}
PDNode
*
input_node
()
{
return
pattern
->
RetrieveNode
(
input_name
());
}
std
::
string
filter_name
()
{
return
PDNodeName
(
name_scope_
,
repr_
,
id_
,
"Filter"
);
}
PDNode
*
filter_node
()
{
return
pattern
->
RetrieveNode
(
filter_name
());
}
std
::
string
output_name
()
{
return
PDNodeName
(
name_scope
,
repr_
,
id_
,
"Output"
);
}
PDNode
*
output_node
()
{
return
pattern
->
RetrieveNode
(
output_name
());
}
PDNode
*
operator
()()
{
auto
conv_op
=
pattern
->
NewNode
(
conv_name
())
->
assert_is_op
(
"conv2d"
);
auto
input_var
=
pattern
->
NewNode
(
input_name
())
->
AsInput
()
->
assert_is_op_input
(
conv_name
());
auto
filter_var
=
pattern
->
NewNode
(
filter_name
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
conv_name
());
auto
output_var
=
patterh
->
NewNode
(
output_name
())
->
AsOutput
()
->
assert_is_op_output
(
conv_name
());
conv_op
->
LinksFrom
({
input_var
,
filter_var
});
conv_op
->
LinksTo
({
output_var
};
return
output_var
;
}
};
struct
ElementwiseAdd
:
public
PatternBase
{
Conv
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
{
pattern
,
name_scope
,
"elementwise_add"
}
{
}
std
::
string
elementwise_add_name
()
{
return
PDNodeName
(
name_scope_
,
repr_
,
id_
,
"elementwise_add"
);
}
PDNode
*
elementwise_add_node
()
{
return
pattern
->
RetrieveNode
(
elementwise_add_name
());
}
std
::
string
x_name
()
{
return
PDNodeName
(
name_scope
,
repr_
,
id_
,
"X"
);
}
PDNode
*
x_node
()
{
return
pattern
->
RetrieveNode
(
x_name
());
}
std
::
string
y_name
()
{
return
PDNodeName
(
name_scope_
,
repr_
,
id_
,
"Y"
);
}
PDNode
*
y_node
()
{
return
pattern
->
RetrieveNode
(
y_name
());
}
std
::
string
out_name
()
{
return
PDNodeName
(
name_scope
,
repr_
,
id_
,
"Out"
);
}
PDNode
*
out_node
()
{
return
pattern
->
RetrieveNode
(
out_name
());
}
PDNode
*
operator
()(
PDNode
*
conv_output
)
{
auto
elementwise_add_op
=
pattern
->
NewNode
(
conv_name
())
->
assert_is_op
(
"elementwise_add"
);
auto
x_var
=
pattern
->
NewNode
(
x_name
())
->
AsInput
()
->
assert_is_op_input
(
elementwise_add_name
());
conv_output
->
assert_is_op_input
(
elementwise_add_name
(),
y_name
());
// auto y_var = pattern->NewNode(y_name())
// ->AsInput()
// ->assert_is_op_input(elementwise_add_name());
auto
out_var
=
pattern
->
NewNode
(
out_name
())
->
AsOutput
()
->
assert_is_op_output
(
elementwise_add_name
());
conv_op
->
LinksFrom
({
x_var
,
conv_output
});
conv_op
->
LinksTo
({
out_var
};
return
out_var
;
}
};
}
// namespace patterns
using
graph_ptr
=
std
::
unique_ptr
<
ir
::
Graph
>
;
graph_ptr
MKLDNNConvElementwiseAddFusePass
::
ApplyImpl
(
graph_ptr
)
const
{
FusePassBase
::
Init
(
"mkldnn_conv_elementwise_add_fuse"
,
graph
.
get
());
GraphPatternDetector
gpd
;
auto
pattern
=
gpd
.
mutable_pattern
();
patterns
::
Conv
conv_pattern
(
pattern
,
name_scope_
);
auto
conv_output
=
conv_pattern
();
conv_output
->
AsIntermediate
();
patterns
::
ElementwiseAdd
elementwise_add_pattern
(
pattern
,
name_scope_
);
auto
elementwis_add_output
=
elementwise_add_pattern
(
conv_output
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/mkldnn_conv_elementwise_add_fuse_pass.h
0 → 100644
浏览文件 @
9ce343f8
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
MKLDNNConvElementwiseAddFusePass
:
public
FusePassBase
{
public:
virtual
~
FCGRUFusePass
()
{}
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
;
const
std
::
string
name_scope_
{
"mkldnn_conv_elementwise_add_fuse"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录