Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
fb7a50b2
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
fb7a50b2
编写于
9月 26, 2018
作者:
T
Tomasz Patejko
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MKLDNN conv + elementwise_add fusion: removed commented code. Internal functions marked as static.
test=develop
上级
f6881971
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
3 addition
and
102 deletion
+3
-102
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
...uid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
+3
-102
未找到文件。
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
浏览文件 @
fb7a50b2
...
...
@@ -22,112 +22,13 @@ namespace framework {
namespace
ir
{
namespace
patterns
{
/*
struct Pattern : public PatternBase {
Pattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase{pattern, name_scope, ""} {}
private:
std::string name_scope() { return name_scope_; }
std::string repr() { return repr_; }
size_t id() { return id_; }
PDPattern* node_pattern() { return pattern; }
public:
std::string node_name(std::string op_name) {
return PDNodeName(name_scope(), repr(), id(), op_name);
}
PDNode* retrieve_node(std::string op_name) {
return node_pattern()->RetrieveNode(node_name(op_name));
}
PDNode* new_node(std::string op_name) {
return node_pattern()->NewNode(node_name(op_name));
}
};
*/
/*
struct Conv {
std::string op_name() const { return "conv2d"; }
std::string input_name() const { return "Input"; }
std::string bias_name() const { return "Bias"; }
std::string filter_name() const { return "Filter"; }
std::string residual_data_name() const { return "ResidualData"; }
std::string output_name() const { return "Output"; }
std::function<PDNode*()> operator()(std::shared_ptr<Pattern> pattern) {
return [&]() -> PDNode* {
auto conv_op = pattern->new_node(op_name())->assert_is_op(op_name());
auto input_var = pattern->new_node(input_name())
->assert_is_op_input(op_name(), input_name());
auto bias_var = pattern->new_node(bias_name())
->assert_is_op_input(op_name(), bias_name());
auto filter_var = pattern->new_node(filter_name())
->assert_is_op_input(op_name(), filter_name());
auto output_var = pattern->new_node(output_name())
->assert_is_op_output(op_name(), output_name());
conv_op->LinksFrom({input_var, bias_var, filter_var});
conv_op->LinksTo({output_var});
return output_var;
};
}
};
struct ElementwiseAdd {
std::string op_name() const { return "elementwise_add"; }
std::string x_name() const { return "X"; }
std::string y_name() const { return "Y"; }
std::string out_name() const { return "Out"; }
std::function<PDNode*(PDNode*)> operator()(std::shared_ptr<Pattern> pattern) {
return [&](PDNode* conv_output) -> PDNode* {
auto elementwise_add_op =
pattern->new_node(op_name())->assert_is_op(op_name());
auto x_var =
pattern->new_node(x_name())->assert_is_op_input(op_name(), x_name());
conv_output->assert_is_op_input(op_name(), y_name());
auto out_var = pattern->new_node(out_name())
->AsOutput()
->assert_is_op_output(op_name(), out_name());
elementwise_add_op->LinksFrom({x_var, conv_output});
elementwise_add_op->LinksTo({out_var});
return out_var;
};
}
};
*/
/*
Node* GetNodeFromSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
std::shared_ptr<patterns::Pattern> pattern,
const std::string& op_name) {
PADDLE_ENFORCE(subgraph.count(pattern->retrieve_node(op_name)),
"Node not found for PDNode %s", pattern->node_name(op_name));
Node* var = subgraph.at(pattern->retrieve_node(op_name));
PADDLE_ENFORCE(var, "node %s not exists in the sub-graph");
return var;
}
*/
void
LinkNodes
(
Node
*
from
,
Node
*
to
)
{
static
void
LinkNodes
(
Node
*
from
,
Node
*
to
)
{
from
->
outputs
.
push_back
(
to
);
to
->
inputs
.
push_back
(
from
);
}
template
<
typename
IT
,
typename
FindFunc
,
typename
ReplaceFunc
>
void
ReplaceAllOccurances
(
IT
s
,
IT
e
,
FindFunc
f
,
ReplaceFunc
r
)
{
static
void
ReplaceAllOccurances
(
IT
s
,
IT
e
,
FindFunc
f
,
ReplaceFunc
r
)
{
if
(
s
==
e
)
return
;
auto
it
=
std
::
find_if
(
s
,
e
,
f
);
...
...
@@ -140,7 +41,7 @@ void ReplaceAllOccurances(IT s, IT e, FindFunc f, ReplaceFunc r) {
ReplaceAllOccurances
(
it
,
e
,
f
,
r
);
}
void
CorrectGraphEdges
(
Graph
*
graph
,
Node
*
from
,
Node
*
to
)
{
static
void
CorrectGraphEdges
(
Graph
*
graph
,
Node
*
from
,
Node
*
to
)
{
for
(
auto
&
node
:
GraphTraits
::
DFS
(
*
graph
))
{
auto
same
=
std
::
find_if
(
std
::
begin
(
node
.
inputs
),
std
::
end
(
node
.
inputs
),
[
from
](
Node
*
n
)
{
return
n
==
from
;
});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录