Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1be6bf45
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1be6bf45
编写于
8月 12, 2020
作者:
Y
Yiqun Liu
提交者:
GitHub
8月 12, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add assign to fusion_group and enhance inplace execution in fusion_group. (#26121)
上级
b2034c28
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
231 addition
and
187 deletion
+231
-187
paddle/fluid/framework/ir/fusion_group/code_generator.cc
paddle/fluid/framework/ir/fusion_group/code_generator.cc
+60
-37
paddle/fluid/framework/ir/fusion_group/code_generator.h
paddle/fluid/framework/ir/fusion_group/code_generator.h
+1
-1
paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
...e/fluid/framework/ir/fusion_group/code_generator_helper.h
+6
-6
paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc
.../fluid/framework/ir/fusion_group/code_generator_tester.cc
+2
-3
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
...d/framework/ir/fusion_group/elementwise_group_detector.cc
+1
-1
paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
+29
-39
paddle/fluid/framework/ir/fusion_group/operation.cc
paddle/fluid/framework/ir/fusion_group/operation.cc
+13
-9
paddle/fluid/framework/ir/fusion_group/subgraph.h
paddle/fluid/framework/ir/fusion_group/subgraph.h
+62
-40
paddle/fluid/operators/fused/fusion_group_op.cc
paddle/fluid/operators/fused/fusion_group_op.cc
+18
-9
paddle/fluid/operators/fused/fusion_group_op.h
paddle/fluid/operators/fused/fusion_group_op.h
+16
-16
paddle/fluid/operators/fused/fusion_group_op_test.cc
paddle/fluid/operators/fused/fusion_group_op_test.cc
+12
-14
python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
...dle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
+11
-12
未找到文件。
paddle/fluid/framework/ir/fusion_group/code_generator.cc
浏览文件 @
1be6bf45
...
@@ -68,11 +68,35 @@ static bool HasInput(Node* n, std::string name) {
...
@@ -68,11 +68,35 @@ static bool HasInput(Node* n, std::string name) {
return
input_names_set
.
find
(
name
)
!=
input_names_set
.
end
();
return
input_names_set
.
find
(
name
)
!=
input_names_set
.
end
();
}
}
static
Node
*
GetInputVar
(
Node
*
n
,
const
std
::
string
&
name
)
{
PADDLE_ENFORCE_EQ
(
n
&&
n
->
IsOp
()
&&
n
->
Op
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Expected node %p to be an operator node."
,
n
));
for
(
auto
*
in
:
n
->
inputs
)
{
if
(
in
->
Name
()
==
name
)
{
return
in
;
}
}
return
nullptr
;
}
static
Node
*
GetOutputVar
(
Node
*
n
,
const
std
::
string
&
name
)
{
PADDLE_ENFORCE_EQ
(
n
&&
n
->
IsOp
()
&&
n
->
Op
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Expected node %p to be an operator node."
,
n
));
for
(
auto
*
out
:
n
->
outputs
)
{
if
(
out
->
Name
()
==
name
)
{
return
out
;
}
}
return
nullptr
;
}
std
::
vector
<
OperationExpression
>
CodeGenerator
::
ConvertToExpressions
(
std
::
vector
<
OperationExpression
>
CodeGenerator
::
ConvertToExpressions
(
SubGraph
*
subgraph
)
{
SubGraph
*
subgraph
)
{
std
::
unordered_map
<
std
::
string
,
int
>
var_ids
=
EncodeVarNodes
(
subgraph
);
std
::
unordered_map
<
Node
*
,
int
>
var_ids
=
EncodeVarNodes
(
subgraph
);
std
::
vector
<
Node
*>
intermediate_out_nodes
=
std
::
unordered_set
<
Node
*>
intermediate_out_vars_set
=
subgraph
->
GetIntermediateOutVarNodes
();
subgraph
->
GetIntermediateOutVarNodes
Set
();
std
::
vector
<
OperationExpression
>
expressions
;
std
::
vector
<
OperationExpression
>
expressions
;
for
(
auto
*
node
:
subgraph
->
SortedNodes
())
{
for
(
auto
*
node
:
subgraph
->
SortedNodes
())
{
if
(
node
&&
node
->
IsOp
()
&&
node
->
Op
())
{
if
(
node
&&
node
->
IsOp
()
&&
node
->
Op
())
{
...
@@ -92,11 +116,12 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
...
@@ -92,11 +116,12 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
// "elementwise_add_grad", where "X", "Y" and "Out" are not used.
// "elementwise_add_grad", where "X", "Y" and "Out" are not used.
if
((
HasInput
(
node
,
name
)
&&
op
->
Input
(
name
).
size
()
>=
1U
))
{
if
((
HasInput
(
node
,
name
)
&&
op
->
Input
(
name
).
size
()
>=
1U
))
{
for
(
size_t
i
=
0
;
i
<
op
->
Input
(
name
).
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
op
->
Input
(
name
).
size
();
i
++
)
{
Node
*
input_var
=
GetInputVar
(
node
,
op
->
Input
(
name
)[
i
]);
PADDLE_ENFORCE_NE
(
PADDLE_ENFORCE_NE
(
var_ids
.
find
(
op
->
Input
(
name
)[
i
]
),
var_ids
.
end
(),
var_ids
.
find
(
input_var
),
var_ids
.
end
(),
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Input(%s) of operation %s is not set."
,
name
,
op
->
Type
()));
"Input(%s) of operation %s is not set."
,
name
,
op
->
Type
()));
input_ids
.
push_back
(
var_ids
[
op
->
Input
(
name
)[
i
]
]);
input_ids
.
push_back
(
var_ids
[
input_var
]);
}
}
}
else
{
}
else
{
input_ids
.
push_back
(
-
1
);
input_ids
.
push_back
(
-
1
);
...
@@ -106,31 +131,29 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
...
@@ -106,31 +131,29 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
// Output ids should be set in fixed order, like:
// Output ids should be set in fixed order, like:
// - dx, dy in backward operations
// - dx, dy in backward operations
std
::
vector
<
int
>
output_ids
;
std
::
vector
<
int
>
output_ids
;
std
::
vector
<
int
>
intermediate_output_ids
;
std
::
vector
<
std
::
string
>
output_names
=
std
::
vector
<
std
::
string
>
output_names
=
OperationMap
::
Instance
().
Get
(
op
->
Type
()).
output_names
;
OperationMap
::
Instance
().
Get
(
op
->
Type
()).
output_names
;
std
::
unordered_map
<
int
,
bool
>
intermediate_state
;
for
(
auto
&
name
:
output_names
)
{
for
(
auto
&
name
:
output_names
)
{
Node
*
output_var
=
GetOutputVar
(
node
,
op
->
Output
(
name
)[
0
]);
PADDLE_ENFORCE_NE
(
PADDLE_ENFORCE_NE
(
var_ids
.
find
(
o
p
->
Output
(
name
)[
0
]
),
var_ids
.
end
(),
var_ids
.
find
(
o
utput_var
),
var_ids
.
end
(),
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Output(%s) of operation %s is not set."
,
name
,
op
->
Type
()));
"Output(%s) of operation %s is not set."
,
name
,
op
->
Type
()));
output_ids
.
push_back
(
var_ids
[
op
->
Output
(
name
)[
0
]]);
output_ids
.
push_back
(
var_ids
[
output_var
]);
bool
enable_intermediate
=
false
;
if
(
!
subgraph
->
SaveIntermediateOut
()
&&
for
(
auto
*
n
:
intermediate_out_nodes
)
{
intermediate_out_vars_set
.
find
(
output_var
)
!=
if
(
n
->
Name
()
==
op
->
Output
(
name
)[
0
])
{
intermediate_out_vars_set
.
end
())
{
enable_intermediate
=
true
;
intermediate_output_ids
.
push_back
(
var_ids
[
output_var
]);
break
;
}
}
}
intermediate_state
[
var_ids
[
op
->
Output
(
name
)[
0
]]]
=
enable_intermediate
;
}
}
std
::
string
lhs_type
=
ExtractDataType
(
node
->
outputs
);
std
::
string
lhs_type
=
ExtractDataType
(
node
->
outputs
);
std
::
string
rhs_type
=
ExtractDataType
(
node
->
inputs
);
std
::
string
rhs_type
=
ExtractDataType
(
node
->
inputs
);
auto
expression
=
auto
expression
=
OperationExpression
(
node
->
Name
(),
input_ids
,
output_ids
,
rhs_type
,
OperationExpression
(
node
->
Name
(),
input_ids
,
output_ids
,
rhs_type
,
lhs_type
,
intermediate_
state
);
lhs_type
,
intermediate_
output_ids
);
expression
.
SetAttr
(
attr
);
expression
.
SetAttr
(
attr
);
expressions
.
push_back
(
expression
);
expressions
.
push_back
(
expression
);
}
}
...
@@ -146,17 +169,18 @@ std::string CodeGenerator::Generate(
...
@@ -146,17 +169,18 @@ std::string CodeGenerator::Generate(
// TODO(liuyiqun): Check whether all expressions are elementwise operations.
// TODO(liuyiqun): Check whether all expressions are elementwise operations.
std
::
set
<
int
>
input_ids
=
std
::
move
(
DistilInputIds
(
expressions
));
std
::
set
<
int
>
input_ids
=
std
::
move
(
DistilInputIds
(
expressions
));
std
::
set
<
int
>
output_ids
=
std
::
move
(
DistilOutputIds
(
expressions
));
std
::
set
<
int
>
output_ids
=
std
::
move
(
DistilOutputIds
(
expressions
));
std
::
set
<
int
>
intermediate_ids
=
std
::
set
<
int
>
intermediate_
output_
ids
=
std
::
move
(
DistilIntermediateIds
(
expressions
));
std
::
move
(
DistilIntermediateIds
(
expressions
));
std
::
unordered_map
<
int
,
std
::
string
>
dtypes
=
std
::
unordered_map
<
int
,
std
::
string
>
dtypes
=
std
::
move
(
DistilDtypes
(
expressions
));
std
::
move
(
DistilDtypes
(
expressions
));
TemplateVariable
template_var
;
TemplateVariable
template_var
;
template_var
.
Add
(
"func_name"
,
func_name
);
template_var
.
Add
(
"func_name"
,
func_name
);
template_var
.
Add
(
"parameters"
,
EmitParameters
(
input_ids
,
output_ids
,
template_var
.
Add
(
intermediate_ids
,
dtypes
));
"parameters"
,
EmitParameters
(
input_ids
,
output_ids
,
intermediate_output_ids
,
dtypes
));
template_var
.
Add
(
"compute_body"
,
template_var
.
Add
(
"compute_body"
,
EmitComputeBody
(
expressions
,
input_ids
,
output_ids
,
EmitComputeBody
(
expressions
,
input_ids
,
output_ids
,
intermediate_ids
,
dtypes
));
intermediate_
output_
ids
,
dtypes
));
std
::
set
<
std
::
string
>
all_dtype
;
std
::
set
<
std
::
string
>
all_dtype
;
for
(
const
auto
&
type
:
dtypes
)
{
for
(
const
auto
&
type
:
dtypes
)
{
...
@@ -204,18 +228,14 @@ std::set<int> CodeGenerator::DistilOutputIds(
...
@@ -204,18 +228,14 @@ std::set<int> CodeGenerator::DistilOutputIds(
std
::
set
<
int
>
CodeGenerator
::
DistilIntermediateIds
(
std
::
set
<
int
>
CodeGenerator
::
DistilIntermediateIds
(
const
std
::
vector
<
OperationExpression
>&
expressions
)
{
const
std
::
vector
<
OperationExpression
>&
expressions
)
{
std
::
set
<
int
>
intermediate_ids
;
std
::
set
<
int
>
intermediate_
output_
ids
;
// Use std::set to remove the reptead id and get a ordered list.
// Use std::set to remove the reptead id and get a ordered list.
for
(
size_t
i
=
0
;
i
<
expressions
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
expressions
.
size
();
i
++
)
{
for
(
auto
id
:
expressions
[
i
].
GetOutputIds
())
{
for
(
auto
id
:
expressions
[
i
].
GetIntermediateOutputIds
())
{
auto
intermediate_state
=
expressions
[
i
].
GetIntermediateState
();
intermediate_output_ids
.
insert
(
id
);
if
(
intermediate_state
.
find
(
id
)
!=
intermediate_state
.
end
()
&&
intermediate_state
[
id
])
{
intermediate_ids
.
insert
(
id
);
}
}
}
}
}
return
intermediate_ids
;
return
intermediate_
output_
ids
;
}
}
std
::
unordered_map
<
int
,
std
::
string
>
CodeGenerator
::
DistilDtypes
(
std
::
unordered_map
<
int
,
std
::
string
>
CodeGenerator
::
DistilDtypes
(
...
@@ -316,26 +336,29 @@ std::string CodeGenerator::EmitComputeBody(
...
@@ -316,26 +336,29 @@ std::string CodeGenerator::EmitComputeBody(
return
load
.
str
()
+
compute
.
str
()
+
store
.
str
();
return
load
.
str
()
+
compute
.
str
()
+
store
.
str
();
}
}
std
::
unordered_map
<
std
::
string
,
int
>
CodeGenerator
::
EncodeVarNodes
(
std
::
unordered_map
<
Node
*
,
int
>
CodeGenerator
::
EncodeVarNodes
(
SubGraph
*
subgraph
)
{
SubGraph
*
subgraph
)
{
const
auto
&
input_var_nodes
=
subgraph
->
GetInputVarNodes
();
const
auto
&
input_var_nodes
=
subgraph
->
GetInputVarNodes
();
const
auto
&
output_var_nodes
=
subgraph
->
GetOutputVarNodes
();
// Encode all var nodes, including intermediate output var nodes.
const
auto
&
output_var_nodes
=
subgraph
->
GetOutputVarNodes
(
true
);
int
id
=
0
;
int
id
=
0
;
std
::
unordered_map
<
std
::
string
,
int
>
var_ids
;
std
::
unordered_map
<
Node
*
,
int
>
var_ids
;
// Numbering input vars.
// Numbering input vars.
for
(
auto
*
in
:
input_var_nodes
)
{
for
(
auto
*
in
:
input_var_nodes
)
{
VLOG
(
3
)
<<
"Encoding input names:"
<<
in
->
Name
()
<<
", id:"
<<
id
;
VLOG
(
3
)
<<
"Encoding input names:"
<<
in
->
Name
()
<<
"("
<<
in
if
(
var_ids
.
find
(
in
->
Name
())
==
var_ids
.
end
())
{
<<
"), id:"
<<
id
;
var_ids
[
in
->
Name
()]
=
id
++
;
if
(
var_ids
.
find
(
in
)
==
var_ids
.
end
())
{
var_ids
[
in
]
=
id
++
;
}
}
}
}
// Encoding output vars.
// Encoding output vars.
for
(
auto
*
out
:
output_var_nodes
)
{
for
(
auto
*
out
:
output_var_nodes
)
{
VLOG
(
3
)
<<
"Ecoding output names:"
<<
out
->
Name
()
<<
", id:"
<<
id
;
VLOG
(
3
)
<<
"Ecoding output names:"
<<
out
->
Name
()
<<
"("
<<
out
if
(
var_ids
.
find
(
out
->
Name
())
==
var_ids
.
end
())
{
<<
"), id:"
<<
id
;
var_ids
[
out
->
Name
()]
=
id
++
;
if
(
var_ids
.
find
(
out
)
==
var_ids
.
end
())
{
var_ids
[
out
]
=
id
++
;
}
}
}
}
return
var_ids
;
return
var_ids
;
...
...
paddle/fluid/framework/ir/fusion_group/code_generator.h
浏览文件 @
1be6bf45
...
@@ -61,7 +61,7 @@ class CodeGenerator {
...
@@ -61,7 +61,7 @@ class CodeGenerator {
const
std
::
unordered_map
<
int
,
std
::
string
>&
dtypes
)
const
;
const
std
::
unordered_map
<
int
,
std
::
string
>&
dtypes
)
const
;
// Encode all var nodes in the subgraph with an unique number.
// Encode all var nodes in the subgraph with an unique number.
std
::
unordered_map
<
std
::
string
,
int
>
EncodeVarNodes
(
SubGraph
*
subgraph
);
std
::
unordered_map
<
Node
*
,
int
>
EncodeVarNodes
(
SubGraph
*
subgraph
);
private:
private:
std
::
vector
<
CodeTemplate
>
code_templates_
;
std
::
vector
<
CodeTemplate
>
code_templates_
;
...
...
paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
浏览文件 @
1be6bf45
...
@@ -48,20 +48,20 @@ class OperationExpression {
...
@@ -48,20 +48,20 @@ class OperationExpression {
std
::
string
op_type
,
const
std
::
vector
<
int
>&
input_ids
,
std
::
string
op_type
,
const
std
::
vector
<
int
>&
input_ids
,
const
std
::
vector
<
int
>&
output_ids
,
std
::
string
rhs_type
,
const
std
::
vector
<
int
>&
output_ids
,
std
::
string
rhs_type
,
std
::
string
lhs_type
,
std
::
string
lhs_type
,
const
std
::
unordered_map
<
int
,
bool
>&
intermediate_state
=
{})
const
std
::
vector
<
int
>&
intermediate_output_ids
=
{})
:
op_type_
(
op_type
),
:
op_type_
(
op_type
),
input_ids_
(
input_ids
),
input_ids_
(
input_ids
),
output_ids_
(
output_ids
),
output_ids_
(
output_ids
),
rhs_type_
(
rhs_type
),
rhs_type_
(
rhs_type
),
lhs_type_
(
lhs_type
),
lhs_type_
(
lhs_type
),
intermediate_
state_
(
intermediate_state
)
{}
intermediate_
output_ids_
(
intermediate_output_ids
)
{}
std
::
string
GetOpType
()
const
{
return
op_type_
;
}
std
::
string
GetOpType
()
const
{
return
op_type_
;
}
std
::
unordered_map
<
int
,
bool
>
GetIntermediateState
()
const
{
return
intermediate_state_
;
}
std
::
vector
<
int
>
GetInputIds
()
const
{
return
input_ids_
;
}
std
::
vector
<
int
>
GetInputIds
()
const
{
return
input_ids_
;
}
std
::
vector
<
int
>
GetOutputIds
()
const
{
return
output_ids_
;
}
std
::
vector
<
int
>
GetOutputIds
()
const
{
return
output_ids_
;
}
std
::
vector
<
int
>
GetIntermediateOutputIds
()
const
{
return
intermediate_output_ids_
;
}
std
::
string
GetRHSType
()
const
{
return
rhs_type_
;
}
std
::
string
GetRHSType
()
const
{
return
rhs_type_
;
}
std
::
string
GetLHSType
()
const
{
return
lhs_type_
;
}
std
::
string
GetLHSType
()
const
{
return
lhs_type_
;
}
void
SetAttr
(
AttributeMap
attr
)
{
attr_
=
attr
;
}
void
SetAttr
(
AttributeMap
attr
)
{
attr_
=
attr
;
}
...
@@ -84,7 +84,7 @@ class OperationExpression {
...
@@ -84,7 +84,7 @@ class OperationExpression {
AttributeMap
attr_
;
AttributeMap
attr_
;
std
::
string
rhs_type_
;
std
::
string
rhs_type_
;
std
::
string
lhs_type_
;
std
::
string
lhs_type_
;
std
::
unordered_map
<
int
,
bool
>
intermediate_state
_
;
std
::
vector
<
int
>
intermediate_output_ids
_
;
};
};
class
TemplateVariable
{
class
TemplateVariable
{
...
...
paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc
浏览文件 @
1be6bf45
...
@@ -144,7 +144,6 @@ void CheckOutput(const std::vector<OperationExpression>& expressions,
...
@@ -144,7 +144,6 @@ void CheckOutput(const std::vector<OperationExpression>& expressions,
LOG
(
INFO
)
<<
"Precision check failed from i = "
<<
id
LOG
(
INFO
)
<<
"Precision check failed from i = "
<<
id
<<
", expect: "
<<
expect
<<
", actual: "
<<
actual
;
<<
", expect: "
<<
expect
<<
", actual: "
<<
actual
;
EXPECT_LT
(
fabs
(
actual
-
expect
),
eps
);
EXPECT_LT
(
fabs
(
actual
-
expect
),
eps
);
break
;
}
}
}
}
}
}
...
@@ -465,7 +464,7 @@ TEST(code_generator, subgraph) {
...
@@ -465,7 +464,7 @@ TEST(code_generator, subgraph) {
for
(
std
::
string
dtype
:
{
"float"
,
"__half"
})
{
for
(
std
::
string
dtype
:
{
"float"
,
"__half"
})
{
std
::
unique_ptr
<
paddle
::
framework
::
ir
::
Graph
>
graph
=
std
::
unique_ptr
<
paddle
::
framework
::
ir
::
Graph
>
graph
=
BuildGraph
(
false
,
dtype
);
BuildGraph
(
false
,
dtype
);
fusion_group
::
SubGraph
subgraph
(
0
,
"elementwise_kernel_1"
,
fals
e
,
fusion_group
::
SubGraph
subgraph
(
0
,
"elementwise_kernel_1"
,
tru
e
,
graph
->
Nodes
());
graph
->
Nodes
());
// Expressions generated by code_generator (they may be different):
// Expressions generated by code_generator (they may be different):
...
@@ -484,7 +483,7 @@ TEST(code_generator, subgraph_grad) {
...
@@ -484,7 +483,7 @@ TEST(code_generator, subgraph_grad) {
for
(
std
::
string
dtype
:
{
"float"
,
"__half"
})
{
for
(
std
::
string
dtype
:
{
"float"
,
"__half"
})
{
std
::
unique_ptr
<
paddle
::
framework
::
ir
::
Graph
>
graph
=
std
::
unique_ptr
<
paddle
::
framework
::
ir
::
Graph
>
graph
=
BuildGraph
(
true
,
dtype
);
BuildGraph
(
true
,
dtype
);
fusion_group
::
SubGraph
subgraph
(
0
,
"elementwise_grad_kernel_1"
,
fals
e
,
fusion_group
::
SubGraph
subgraph
(
0
,
"elementwise_grad_kernel_1"
,
tru
e
,
DistilGradNodes
(
graph
));
DistilGradNodes
(
graph
));
// Expressions generated by code_generator (they may be different):
// Expressions generated by code_generator (they may be different):
...
...
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
浏览文件 @
1be6bf45
...
@@ -63,7 +63,7 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
...
@@ -63,7 +63,7 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
bool
GroupDetector
::
CheckPrecondition
(
const
Node
*
n
)
{
bool
GroupDetector
::
CheckPrecondition
(
const
Node
*
n
)
{
auto
check_data_type
=
[
&
](
const
std
::
vector
<
Node
*>&
nodes
)
->
bool
{
auto
check_data_type
=
[
&
](
const
std
::
vector
<
Node
*>&
nodes
)
->
bool
{
bool
is_first
=
true
;
bool
is_first
=
true
;
proto
::
VarType
::
Type
data_type_0
;
proto
::
VarType
::
Type
data_type_0
=
proto
::
VarType
::
BOOL
;
for
(
auto
*
n
:
nodes
)
{
for
(
auto
*
n
:
nodes
)
{
if
(
n
&&
n
->
IsVar
()
&&
n
->
Var
())
{
if
(
n
&&
n
->
IsVar
()
&&
n
->
Var
())
{
if
(
n
->
Var
()
->
GetType
()
!=
proto
::
VarType
::
LOD_TENSOR
)
{
if
(
n
->
Var
()
->
GetType
()
!=
proto
::
VarType
::
LOD_TENSOR
)
{
...
...
paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
浏览文件 @
1be6bf45
...
@@ -63,11 +63,6 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
...
@@ -63,11 +63,6 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
std
::
unordered_set
<
Node
*>
(
vec
.
begin
(),
vec
.
end
()));
std
::
unordered_set
<
Node
*>
(
vec
.
begin
(),
vec
.
end
()));
VLOG
(
3
)
<<
"subgraph: {
\n
"
<<
DebugString
(
subgraph
.
SortedNodes
())
<<
"}
\n
"
;
VLOG
(
3
)
<<
"subgraph: {
\n
"
<<
DebugString
(
subgraph
.
SortedNodes
())
<<
"}
\n
"
;
// In elementwise fused kernel, memory is the bound of execution,
// here we remove the output id to use less memory and less time.
if
(
subgraph
.
RemoveIntermediateOut
())
{
subgraph
.
DetectIntermediateOutWithGraph
(
graph
);
}
if
(
subgraph
.
IsValid
(
min_subgraph_size
))
{
if
(
subgraph
.
IsValid
(
min_subgraph_size
))
{
subgraph
.
SetFuncName
(
"fused_elementwise_"
+
std
::
to_string
(
index
++
));
subgraph
.
SetFuncName
(
"fused_elementwise_"
+
std
::
to_string
(
index
++
));
if
(
GenerateCode
(
&
subgraph
))
{
if
(
GenerateCode
(
&
subgraph
))
{
...
@@ -115,57 +110,52 @@ static int ExtractOpRole(fusion_group::SubGraph* subgraph) {
...
@@ -115,57 +110,52 @@ static int ExtractOpRole(fusion_group::SubGraph* subgraph) {
void
FusionGroupPass
::
InsertFusionGroupOp
(
void
FusionGroupPass
::
InsertFusionGroupOp
(
Graph
*
graph
,
fusion_group
::
SubGraph
*
subgraph
)
const
{
Graph
*
graph
,
fusion_group
::
SubGraph
*
subgraph
)
const
{
const
std
::
vector
<
Node
*>&
input_vars_of_subgraph
=
const
std
::
vector
<
Node
*>&
input_vars
=
subgraph
->
GetInputVarNodes
();
subgraph
->
GetInputVarNodes
();
const
std
::
vector
<
Node
*>&
output_vars
=
const
std
::
vector
<
Node
*>&
output_vars_of_subgraph
=
subgraph
->
GetOutputVarNodes
(
subgraph
->
SaveIntermediateOut
());
subgraph
->
GetOutputVarNodes
();
const
std
::
vector
<
Node
*>
intermediate_vars_of_subgraph
=
subgraph
->
GetIntermediateOutVarNodes
();
std
::
unordered_set
<
Node
*>
external_nodes
;
std
::
unordered_set
<
Node
*>
external_nodes
;
OpDesc
op_desc
;
// Prepare inputs.
op_desc
.
SetType
(
"fusion_group"
);
std
::
vector
<
std
::
string
>
input_names
;
std
::
vector
<
std
::
string
>
input_names
;
std
::
vector
<
std
::
string
>
inputs_data_types
;
std
::
vector
<
int
>
input_dtypes
;
for
(
auto
*
n
:
input_vars_of_subgraph
)
{
std
::
unordered_set
<
Node
*>
output_vars_set
(
output_vars
.
begin
(),
output_vars
.
end
());
for
(
auto
*
n
:
input_vars
)
{
// It is not an output var node.
if
(
output_vars_set
.
find
(
n
)
==
output_vars_set
.
end
())
{
input_names
.
push_back
(
n
->
Name
());
input_names
.
push_back
(
n
->
Name
());
inputs_data_types
.
push_back
(
DataTypeToString
(
n
->
Var
()
->
GetDataType
()
));
input_dtypes
.
push_back
(
n
->
Var
()
->
GetDataType
(
));
external_nodes
.
insert
(
n
);
external_nodes
.
insert
(
n
);
}
}
op_desc
.
SetInput
(
"Inputs"
,
input_names
);
}
// Prepare outputs.
std
::
vector
<
std
::
string
>
output_names
;
std
::
vector
<
std
::
string
>
output_names
;
std
::
vector
<
std
::
string
>
outs_data_types
;
std
::
vector
<
int
>
output_dtypes
;
std
::
vector
<
Node
*>
output_var_without_intermediate
;
for
(
auto
*
n
:
output_vars
)
{
for
(
auto
*
n
:
output_vars_of_subgraph
)
{
auto
it_input
=
find
(
input_vars_of_subgraph
.
begin
(),
input_vars_of_subgraph
.
end
(),
n
);
auto
it_intermediate
=
find
(
intermediate_vars_of_subgraph
.
begin
(),
intermediate_vars_of_subgraph
.
end
(),
n
);
if
(
it_intermediate
==
intermediate_vars_of_subgraph
.
end
()
&&
it_input
==
input_vars_of_subgraph
.
end
())
{
output_names
.
push_back
(
n
->
Name
());
output_names
.
push_back
(
n
->
Name
());
outs_data_types
.
push_back
(
DataTypeToString
(
n
->
Var
()
->
GetDataType
()));
output_dtypes
.
push_back
(
n
->
Var
()
->
GetDataType
());
output_var_without_intermediate
.
push_back
(
n
);
}
external_nodes
.
insert
(
n
);
external_nodes
.
insert
(
n
);
}
}
OpDesc
op_desc
;
op_desc
.
SetType
(
"fusion_group"
);
op_desc
.
SetInput
(
"Inputs"
,
input_names
);
op_desc
.
SetOutput
(
"Outs"
,
output_names
);
op_desc
.
SetOutput
(
"Outs"
,
output_names
);
op_desc
.
SetAttr
(
"inputs_d
ata_type"
,
inputs_data_
types
);
op_desc
.
SetAttr
(
"inputs_d
type"
,
input_d
types
);
op_desc
.
SetAttr
(
"outs_d
ata_type"
,
outs_data_
types
);
op_desc
.
SetAttr
(
"outs_d
type"
,
output_d
types
);
op_desc
.
SetAttr
(
"type"
,
subgraph
->
GetType
());
op_desc
.
SetAttr
(
"type"
,
subgraph
->
GetType
());
op_desc
.
SetAttr
(
"func_name"
,
subgraph
->
GetFuncName
());
op_desc
.
SetAttr
(
"func_name"
,
subgraph
->
GetFuncName
());
op_desc
.
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
op_desc
.
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
ExtractOpRole
(
subgraph
));
ExtractOpRole
(
subgraph
));
Node
*
fusion_group_node
=
graph
->
CreateOpNode
(
&
op_desc
);
Node
*
fusion_group_node
=
graph
->
CreateOpNode
(
&
op_desc
);
for
(
auto
*
in
:
input_vars_of_subgraph
)
{
for
(
auto
*
in
:
input_vars
)
{
if
(
output_vars_set
.
find
(
in
)
==
output_vars_set
.
end
())
{
IR_NODE_LINK_TO
(
in
,
fusion_group_node
);
IR_NODE_LINK_TO
(
in
,
fusion_group_node
);
}
}
}
for
(
auto
*
out
:
output_var
_without_intermediate
)
{
for
(
auto
*
out
:
output_var
s
)
{
IR_NODE_LINK_TO
(
fusion_group_node
,
out
);
IR_NODE_LINK_TO
(
fusion_group_node
,
out
);
}
}
...
...
paddle/fluid/framework/ir/fusion_group/operation.cc
浏览文件 @
1be6bf45
...
@@ -105,12 +105,6 @@ void OperationMap::InsertUnaryElementwiseOperations() {
...
@@ -105,12 +105,6 @@ void OperationMap::InsertUnaryElementwiseOperations() {
insert_handler
(
"tanh"
,
"%{2.0} / (%{1.0} + Exp(-%{2.0} * ${0})) - %{1.0}"
,
insert_handler
(
"tanh"
,
"%{2.0} / (%{1.0} + Exp(-%{2.0} * ${0})) - %{1.0}"
,
{
"${2} * (%{1.0} - ${1} * ${1})"
});
{
"${2} * (%{1.0} - ${1} * ${1})"
});
// cast:
// out = static_cast<T>(x)
// TODO(wangchaochaohu): This is not the compelete definition of
// cast Op, We need refine it later.
insert_handler
(
"cast"
,
"${0}"
,
{});
// sqrt:
// sqrt:
// out = x^(1/2)
// out = x^(1/2)
// dx = dout * 0.5 / out
// dx = dout * 0.5 / out
...
@@ -121,6 +115,16 @@ void OperationMap::InsertUnaryElementwiseOperations() {
...
@@ -121,6 +115,16 @@ void OperationMap::InsertUnaryElementwiseOperations() {
// dx = dout * 2.0 * x
// dx = dout * 2.0 * x
insert_handler
(
"square"
,
"${0} * ${0}"
,
{
"${2} * %{2.0} * ${0}"
});
insert_handler
(
"square"
,
"${0} * ${0}"
,
{
"${2} * %{2.0} * ${0}"
});
// assign:
// out = x
insert_handler
(
"assign"
,
"${0}"
,
{});
// cast:
// out = static_cast<T>(x)
// TODO(wangchaochaohu): This is not the compelete definition of
// cast Op, We need refine it later.
insert_handler
(
"cast"
,
"${0}"
,
{});
// scale
// scale
// out = (bias_after_scale) ? scale * X + bias : scale(X + bias)
// out = (bias_after_scale) ? scale * X + bias : scale(X + bias)
// here we use '=' operator to seperate th default value
// here we use '=' operator to seperate th default value
...
...
paddle/fluid/framework/ir/fusion_group/subgraph.h
浏览文件 @
1be6bf45
...
@@ -66,11 +66,12 @@ class SubGraph {
...
@@ -66,11 +66,12 @@ class SubGraph {
}
}
int
GetType
()
const
{
return
type_
;
}
int
GetType
()
const
{
return
type_
;
}
bool
RemoveIntermediateOut
()
{
return
!
save_intermediate_out_
;
}
void
SetFuncName
(
std
::
string
func_name
)
{
func_name_
=
func_name
;
}
void
SetFuncName
(
std
::
string
func_name
)
{
func_name_
=
func_name
;
}
std
::
string
GetFuncName
()
const
{
return
func_name_
;
}
std
::
string
GetFuncName
()
const
{
return
func_name_
;
}
bool
SaveIntermediateOut
()
const
{
return
save_intermediate_out_
;
}
const
std
::
unordered_set
<
Node
*>&
Nodes
()
const
{
return
nodes_set_
;
}
const
std
::
unordered_set
<
Node
*>&
Nodes
()
const
{
return
nodes_set_
;
}
const
std
::
vector
<
Node
*>&
SortedNodes
()
{
const
std
::
vector
<
Node
*>&
SortedNodes
()
{
if
(
!
is_sorted_
)
{
if
(
!
is_sorted_
)
{
...
@@ -118,66 +119,88 @@ class SubGraph {
...
@@ -118,66 +119,88 @@ class SubGraph {
return
input_vars
;
return
input_vars
;
}
}
std
::
vector
<
Node
*>
GetOutputVarNodes
()
{
std
::
vector
<
Node
*>
GetOutputVarNodes
(
bool
with_intermediate_out
)
{
// The order of output nodes should be consistant anywhere..
// The order of output nodes should be consistant anywhere..
std
::
vector
<
Node
*>
output_vars
_all
;
std
::
vector
<
Node
*>
output_vars
;
for
(
auto
*
n
:
SortedNodes
())
{
for
(
auto
*
n
:
SortedNodes
())
{
if
(
n
&&
n
->
IsVar
()
&&
n
->
Var
(
))
{
if
(
IsOutputOfInternalOp
(
n
))
{
// If the var_node is the output of some op_node in the subgraph, it
// If the var_node is the output of some op_node in the subgraph, it
// is considered the output var node of the subgraph.
// is considered the output var node of the subgraph.
bool
is_found
=
false
;
if
(
with_intermediate_out
)
{
for
(
auto
*
in
:
n
->
inputs
)
{
output_vars
.
push_back
(
n
);
if
(
Has
(
in
))
{
}
else
{
is_found
=
true
;
if
(
n
->
outputs
.
empty
()
||
IsInputOfExternalOp
(
n
))
{
}
output_vars
.
push_back
(
n
);
}
}
if
(
is_found
)
{
output_vars_all
.
push_back
(
n
);
}
}
}
}
}
}
return
output_vars
_all
;
return
output_vars
;
}
}
std
::
vector
<
Node
*>
GetIntermediateOutVarNodes
()
{
std
::
vector
<
Node
*>
GetIntermediateOutVarNodes
()
{
return
intermediate_out_nodes_
;
// Intermediate output var nodes: the output of some op_node in the
// subgraph, but not referenced outside the subgraph.
std
::
vector
<
Node
*>
intermediate_out_vars
;
for
(
auto
*
n
:
SortedNodes
())
{
if
(
IsOutputOfInternalOp
(
n
)
&&
IsInputOfInternalOp
(
n
)
&&
!
IsInputOfExternalOp
(
n
))
{
// When the outputs size is 0, it is also considered a intermidiate
// output. It maybe an unused output or the fetching vars, so that we
// cannot eleiminate it directly here.
intermediate_out_vars
.
push_back
(
n
);
}
}
return
intermediate_out_vars
;
}
}
void
DetectIntermediateOutWithGraph
(
Graph
*
graph
)
{
std
::
unordered_set
<
Node
*>
GetIntermediateOutVarNodesSet
(
)
{
auto
graph_nodes
=
graph
->
Nodes
();
std
::
vector
<
Node
*>
intermediate_out_vars
=
GetIntermediateOutVar
Nodes
();
return
std
::
unordered_set
<
Node
*>
(
intermediate_out_vars
.
begin
(),
for
(
auto
*
n
:
SortedNodes
())
{
intermediate_out_vars
.
end
());
bool
enable_remove
=
true
;
}
if
(
n
&&
n
->
IsVar
()
&&
n
->
Var
())
{
private:
bool
leaf_graph
=
true
;
bool
IsInputOfInternalOp
(
Node
*
n
)
{
for
(
auto
*
node
:
graph_nodes
)
{
bool
is_input_of_internal_op
=
false
;
if
(
node
->
IsOp
())
{
if
(
Has
(
n
)
&&
n
&&
n
->
IsVar
()
&&
n
->
Var
())
{
auto
inputs
=
node
->
inputs
;
for
(
auto
*
out
:
n
->
outputs
)
{
for
(
auto
*
in
:
inputs
)
{
if
(
Has
(
out
)
)
{
if
(
in
&&
in
->
Name
()
==
n
->
Name
())
{
is_input_of_internal_op
=
true
;
if
(
!
Has
(
node
))
enable_remove
=
false
;
break
;
leaf_graph
=
false
;
}
}
}
}
}
return
is_input_of_internal_op
;
}
}
if
(
!
enable_remove
)
{
bool
IsInputOfExternalOp
(
Node
*
n
)
{
// If n is the input any one node outside the subgraph.
bool
is_input_of_external_op
=
false
;
if
(
Has
(
n
)
&&
n
&&
n
->
IsVar
()
&&
n
->
Var
())
{
for
(
auto
*
out
:
n
->
outputs
)
{
if
(
!
Has
(
out
))
{
is_input_of_external_op
=
true
;
break
;
break
;
}
}
}
}
if
(
leaf_graph
)
enable_remove
=
false
;
}
return
is_input_of_external_op
;
}
else
{
enable_remove
=
false
;
}
}
if
(
enable_remove
)
{
bool
IsOutputOfInternalOp
(
Node
*
n
)
{
intermediate_out_nodes_
.
push_back
(
n
);
bool
is_output_of_internal_op
=
false
;
if
(
Has
(
n
)
&&
n
&&
n
->
IsVar
()
&&
n
->
Var
())
{
for
(
auto
*
in
:
n
->
inputs
)
{
if
(
Has
(
in
))
{
is_output_of_internal_op
=
true
;
break
;
}
}
}
}
}
return
is_output_of_internal_op
;
}
}
private:
void
TopologicalSort
()
{
void
TopologicalSort
()
{
if
(
!
is_sorted_
)
{
if
(
!
is_sorted_
)
{
std
::
unordered_map
<
Node
*
,
std
::
vector
<
Node
*>>
inputs_map
;
std
::
unordered_map
<
Node
*
,
std
::
vector
<
Node
*>>
inputs_map
;
...
@@ -236,7 +259,6 @@ class SubGraph {
...
@@ -236,7 +259,6 @@ class SubGraph {
bool
save_intermediate_out_
{
true
};
bool
save_intermediate_out_
{
true
};
std
::
unordered_set
<
Node
*>
nodes_set_
;
std
::
unordered_set
<
Node
*>
nodes_set_
;
std
::
vector
<
Node
*>
intermediate_out_nodes_
{};
bool
is_sorted_
{
false
};
bool
is_sorted_
{
false
};
std
::
vector
<
Node
*>
sorted_nodes_
;
std
::
vector
<
Node
*>
sorted_nodes_
;
};
};
...
...
paddle/fluid/operators/fused/fusion_group_op.cc
浏览文件 @
1be6bf45
...
@@ -22,8 +22,14 @@ class FusionGroupOp : public framework::OperatorWithKernel {
...
@@ -22,8 +22,14 @@ class FusionGroupOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
const
size_t
num_ins
=
ctx
->
Inputs
(
"Inputs"
).
size
();
OP_INOUT_CHECK
(
ctx
->
HasInputs
(
"Inputs"
),
"Input"
,
"Inputs"
,
"FusionGroup"
);
const
size_t
num_outs
=
ctx
->
Outputs
(
"Outs"
).
size
();
OP_INOUT_CHECK
(
ctx
->
HasOutputs
(
"Outs"
),
"Output"
,
"Outs"
,
"FusionGroup"
);
auto
input_names
=
ctx
->
Inputs
(
"Inputs"
);
auto
output_names
=
ctx
->
Outputs
(
"Outs"
);
const
size_t
num_ins
=
input_names
.
size
();
const
size_t
num_outs
=
output_names
.
size
();
PADDLE_ENFORCE_GE
(
PADDLE_ENFORCE_GE
(
num_ins
,
1UL
,
num_ins
,
1UL
,
...
@@ -42,9 +48,12 @@ class FusionGroupOp : public framework::OperatorWithKernel {
...
@@ -42,9 +48,12 @@ class FusionGroupOp : public framework::OperatorWithKernel {
std
::
vector
<
framework
::
DDim
>
x_dims
=
ctx
->
GetInputsDim
(
"Inputs"
);
std
::
vector
<
framework
::
DDim
>
x_dims
=
ctx
->
GetInputsDim
(
"Inputs"
);
if
(
type
==
0
)
{
if
(
type
==
0
)
{
for
(
size_t
i
=
1
;
i
<
num_ins
;
++
i
)
{
for
(
size_t
i
=
1
;
i
<
num_ins
;
++
i
)
{
PADDLE_ENFORCE_EQ
(
x_dims
[
0
],
x_dims
[
i
],
PADDLE_ENFORCE_EQ
(
x_dims
[
0
],
x_dims
[
i
],
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"All the inputs' dims should be the same."
));
"All the inputs' dims is expected to be the same. "
"But recieved [%s] (name: %s) vs [%s] (name: %s)."
,
x_dims
[
0
],
input_names
[
0
],
x_dims
[
i
],
input_names
[
i
]));
}
}
std
::
vector
<
framework
::
DDim
>
out_dims
;
std
::
vector
<
framework
::
DDim
>
out_dims
;
for
(
size_t
j
=
0
;
j
<
num_outs
;
++
j
)
{
for
(
size_t
j
=
0
;
j
<
num_outs
;
++
j
)
{
...
@@ -76,11 +85,11 @@ class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -76,11 +85,11 @@ class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput
(
"Outs"
,
AddOutput
(
"Outs"
,
"(std::vector<LoDTensor>) The outputs of fusion_group op."
)
"(std::vector<LoDTensor>) The outputs of fusion_group op."
)
.
AsDuplicable
();
.
AsDuplicable
();
AddAttr
<
std
::
vector
<
std
::
string
>>
(
AddAttr
<
std
::
vector
<
int
>>
(
"outs_dtype"
,
"outs_data_type"
,
"The data type of Outputs in fusion_group op."
)
"The data type of Outputs in fusion_group op."
)
.
SetDefault
({});
.
SetDefault
({});
AddAttr
<
std
::
vector
<
std
::
string
>>
(
AddAttr
<
std
::
vector
<
int
>>
(
"inputs_dtype"
,
"inputs_data_type"
,
"The data type of Inputs in fusion_group op."
)
"The data type of Inputs in fusion_group op."
)
.
SetDefault
({});
.
SetDefault
({});
AddAttr
<
int
>
(
"type"
,
"Fusion type."
).
SetDefault
(
0
);
AddAttr
<
int
>
(
"type"
,
"Fusion type."
).
SetDefault
(
0
);
AddAttr
<
std
::
string
>
(
"func_name"
,
"Name of the generated functions."
)
AddAttr
<
std
::
string
>
(
"func_name"
,
"Name of the generated functions."
)
...
...
paddle/fluid/operators/fused/fusion_group_op.h
浏览文件 @
1be6bf45
...
@@ -24,14 +24,14 @@ namespace operators {
...
@@ -24,14 +24,14 @@ namespace operators {
static
void
MutableMultiTypeData
(
static
void
MutableMultiTypeData
(
std
::
vector
<
paddle
::
framework
::
LoDTensor
*>*
var
,
std
::
vector
<
paddle
::
framework
::
LoDTensor
*>*
var
,
const
std
::
vector
<
std
::
string
>&
data_type
,
const
platform
::
Place
&
place
)
{
const
std
::
vector
<
int
>&
data_type
,
const
platform
::
Place
&
place
)
{
for
(
size_t
i
=
0
;
i
<
var
->
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
var
->
size
();
i
++
)
{
if
(
data_type
[
i
]
==
"float"
)
{
if
(
data_type
[
i
]
==
framework
::
proto
::
VarType
::
FP32
)
{
(
*
var
)[
i
]
->
mutable_data
<
float
>
(
place
);
(
*
var
)[
i
]
->
mutable_data
<
float
>
(
place
);
}
else
if
(
data_type
[
i
]
==
"double"
)
{
}
else
if
(
data_type
[
i
]
==
framework
::
proto
::
VarType
::
FP16
)
{
(
*
var
)[
i
]
->
mutable_data
<
double
>
(
place
);
}
else
if
(
data_type
[
i
]
==
"::paddle::platform::float16"
)
{
(
*
var
)[
i
]
->
mutable_data
<
paddle
::
platform
::
float16
>
(
place
);
(
*
var
)[
i
]
->
mutable_data
<
paddle
::
platform
::
float16
>
(
place
);
}
else
if
(
data_type
[
i
]
==
framework
::
proto
::
VarType
::
FP64
)
{
(
*
var
)[
i
]
->
mutable_data
<
double
>
(
place
);
}
}
}
}
}
}
...
@@ -43,15 +43,15 @@ class FusionGroupKernel : public framework::OpKernel<T> {
...
@@ -43,15 +43,15 @@ class FusionGroupKernel : public framework::OpKernel<T> {
auto
ins
=
ctx
.
MultiInput
<
framework
::
LoDTensor
>
(
"Inputs"
);
auto
ins
=
ctx
.
MultiInput
<
framework
::
LoDTensor
>
(
"Inputs"
);
auto
outs
=
ctx
.
MultiOutput
<
framework
::
LoDTensor
>
(
"Outs"
);
auto
outs
=
ctx
.
MultiOutput
<
framework
::
LoDTensor
>
(
"Outs"
);
int
type
=
ctx
.
Attr
<
int
>
(
"type"
);
int
type
=
ctx
.
Attr
<
int
>
(
"type"
);
auto
outs_type
=
ctx
.
Attr
<
std
::
vector
<
std
::
string
>>
(
"outs_data_
type"
);
const
auto
&
outs_dtype
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"outs_d
type"
);
auto
inputs_type
=
ctx
.
Attr
<
std
::
vector
<
std
::
string
>>
(
"inputs_data_
type"
);
const
auto
&
inputs_dtype
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"inputs_d
type"
);
size_t
num_ins
=
ins
.
size
();
size_t
num_ins
=
ins
.
size
();
size_t
num_outs
=
outs
.
size
();
size_t
num_outs
=
outs
.
size
();
auto
place
=
ctx
.
GetPlace
();
auto
place
=
ctx
.
GetPlace
();
MutableMultiTypeData
(
&
outs
,
outs_type
,
place
);
MutableMultiTypeData
(
&
outs
,
outs_
d
type
,
place
);
std
::
string
func_name
=
ctx
.
Attr
<
std
::
string
>
(
"func_name"
);
std
::
string
func_name
=
ctx
.
Attr
<
std
::
string
>
(
"func_name"
);
platform
::
DeviceCode
*
dev_code
=
platform
::
DeviceCode
*
dev_code
=
...
@@ -64,22 +64,22 @@ class FusionGroupKernel : public framework::OpKernel<T> {
...
@@ -64,22 +64,22 @@ class FusionGroupKernel : public framework::OpKernel<T> {
args
.
push_back
(
&
n
);
args
.
push_back
(
&
n
);
std
::
vector
<
const
void
*>
ptrs
(
num_ins
+
num_outs
);
std
::
vector
<
const
void
*>
ptrs
(
num_ins
+
num_outs
);
for
(
size_t
i
=
0
;
i
<
num_ins
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
num_ins
;
++
i
)
{
if
(
inputs_
type
[
i
]
==
"::paddle::platform::float16"
)
{
if
(
inputs_
dtype
[
i
]
==
framework
::
proto
::
VarType
::
FP16
)
{
ptrs
[
i
]
=
ins
[
i
]
->
data
<
paddle
::
platform
::
float16
>
();
ptrs
[
i
]
=
ins
[
i
]
->
data
<
paddle
::
platform
::
float16
>
();
}
else
if
(
inputs_type
[
i
]
==
"double"
)
{
}
else
if
(
inputs_dtype
[
i
]
==
framework
::
proto
::
VarType
::
FP32
)
{
ptrs
[
i
]
=
ins
[
i
]
->
data
<
double
>
();
}
else
if
(
inputs_type
[
i
]
==
"float"
)
{
ptrs
[
i
]
=
ins
[
i
]
->
data
<
float
>
();
ptrs
[
i
]
=
ins
[
i
]
->
data
<
float
>
();
}
else
if
(
inputs_dtype
[
i
]
==
framework
::
proto
::
VarType
::
FP64
)
{
ptrs
[
i
]
=
ins
[
i
]
->
data
<
double
>
();
}
}
args
.
push_back
(
&
ptrs
[
i
]);
args
.
push_back
(
&
ptrs
[
i
]);
}
}
for
(
size_t
j
=
0
;
j
<
num_outs
;
++
j
)
{
for
(
size_t
j
=
0
;
j
<
num_outs
;
++
j
)
{
if
(
outs_
type
[
j
]
==
"::paddle::platform::float16"
)
{
if
(
outs_
dtype
[
j
]
==
framework
::
proto
::
VarType
::
FP16
)
{
ptrs
[
num_ins
+
j
]
=
outs
[
j
]
->
data
<
paddle
::
platform
::
float16
>
();
ptrs
[
num_ins
+
j
]
=
outs
[
j
]
->
data
<
paddle
::
platform
::
float16
>
();
}
else
if
(
outs_type
[
j
]
==
"double"
)
{
}
else
if
(
outs_dtype
[
j
]
==
framework
::
proto
::
VarType
::
FP32
)
{
ptrs
[
num_ins
+
j
]
=
outs
[
j
]
->
data
<
double
>
();
}
else
if
(
outs_type
[
j
]
==
"float"
)
{
ptrs
[
num_ins
+
j
]
=
outs
[
j
]
->
data
<
float
>
();
ptrs
[
num_ins
+
j
]
=
outs
[
j
]
->
data
<
float
>
();
}
else
if
(
outs_dtype
[
j
]
==
framework
::
proto
::
VarType
::
FP64
)
{
ptrs
[
num_ins
+
j
]
=
outs
[
j
]
->
data
<
double
>
();
}
}
args
.
push_back
(
&
ptrs
[
num_ins
+
j
]);
args
.
push_back
(
&
ptrs
[
num_ins
+
j
]);
}
}
...
...
paddle/fluid/operators/fused/fusion_group_op_test.cc
浏览文件 @
1be6bf45
...
@@ -57,10 +57,14 @@ framework::OpDesc* CreateFusionGroupOp(
...
@@ -57,10 +57,14 @@ framework::OpDesc* CreateFusionGroupOp(
const
std
::
vector
<
std
::
string
>&
input_names
,
const
std
::
vector
<
std
::
string
>&
input_names
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
input_shapes
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
input_shapes
,
const
std
::
vector
<
std
::
string
>&
output_names
,
int
type
,
const
std
::
vector
<
std
::
string
>&
output_names
,
int
type
,
const
std
::
vector
<
std
::
string
>&
inputs_data_type
,
std
::
string
func_name
)
{
const
std
::
vector
<
std
::
string
>&
outs_data_type
,
std
::
string
func_name
)
{
EXPECT_EQ
(
input_names
.
size
(),
input_shapes
.
size
());
EXPECT_EQ
(
input_names
.
size
(),
input_shapes
.
size
());
std
::
vector
<
int
>
input_dtypes
(
input_names
.
size
(),
framework
::
proto
::
VarType
::
FP32
);
std
::
vector
<
int
>
output_dtypes
(
output_names
.
size
(),
framework
::
proto
::
VarType
::
FP32
);
for
(
size_t
i
=
0
;
i
<
input_names
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
input_names
.
size
();
++
i
)
{
auto
*
var
=
program
->
MutableBlock
(
0
)
->
Var
(
input_names
[
i
]);
auto
*
var
=
program
->
MutableBlock
(
0
)
->
Var
(
input_names
[
i
]);
var
->
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
var
->
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
...
@@ -77,8 +81,8 @@ framework::OpDesc* CreateFusionGroupOp(
...
@@ -77,8 +81,8 @@ framework::OpDesc* CreateFusionGroupOp(
op
->
SetType
(
"fusion_group"
);
op
->
SetType
(
"fusion_group"
);
op
->
SetInput
(
"Inputs"
,
input_names
);
op
->
SetInput
(
"Inputs"
,
input_names
);
op
->
SetOutput
(
"Outs"
,
output_names
);
op
->
SetOutput
(
"Outs"
,
output_names
);
op
->
SetAttr
(
"inputs_d
ata_type"
,
inputs_data_type
);
op
->
SetAttr
(
"inputs_d
type"
,
input_dtypes
);
op
->
SetAttr
(
"outs_d
ata_type"
,
outs_data_type
);
op
->
SetAttr
(
"outs_d
type"
,
output_dtypes
);
op
->
SetAttr
(
"type"
,
type
);
op
->
SetAttr
(
"type"
,
type
);
op
->
SetAttr
(
"func_name"
,
func_name
);
op
->
SetAttr
(
"func_name"
,
func_name
);
op
->
SetAttr
(
framework
::
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
op
->
SetAttr
(
framework
::
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
...
@@ -133,8 +137,6 @@ void CheckOutputs(framework::Scope* scope,
...
@@ -133,8 +137,6 @@ void CheckOutputs(framework::Scope* scope,
void
TestMain
(
const
std
::
vector
<
std
::
string
>&
input_names
,
void
TestMain
(
const
std
::
vector
<
std
::
string
>&
input_names
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
input_shapes
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
input_shapes
,
const
std
::
vector
<
std
::
string
>&
output_names
,
int
type
,
const
std
::
vector
<
std
::
string
>&
output_names
,
int
type
,
const
std
::
vector
<
std
::
string
>&
inputs_data_type
,
const
std
::
vector
<
std
::
string
>&
outs_data_type
,
std
::
string
func_name
,
std
::
string
cuda_kernel_str
,
std
::
string
func_name
,
std
::
string
cuda_kernel_str
,
CPUKernelFunc
cpu_kernel_func
)
{
CPUKernelFunc
cpu_kernel_func
)
{
// Compile the device code
// Compile the device code
...
@@ -144,9 +146,8 @@ void TestMain(const std::vector<std::string>& input_names,
...
@@ -144,9 +146,8 @@ void TestMain(const std::vector<std::string>& input_names,
// Create a ProgramDesc that has a fusion_group_op.
// Create a ProgramDesc that has a fusion_group_op.
framework
::
ProgramDesc
program
;
framework
::
ProgramDesc
program
;
framework
::
OpDesc
*
op_desc
=
framework
::
OpDesc
*
op_desc
=
CreateFusionGroupOp
(
CreateFusionGroupOp
(
&
program
,
input_names
,
input_shapes
,
output_names
,
&
program
,
input_names
,
input_shapes
,
output_names
,
type
,
func_name
);
type
,
inputs_data_type
,
outs_data_type
,
func_name
);
auto
fusion_group_op
=
framework
::
OpRegistry
::
CreateOp
(
*
op_desc
);
auto
fusion_group_op
=
framework
::
OpRegistry
::
CreateOp
(
*
op_desc
);
framework
::
Scope
scope
;
framework
::
Scope
scope
;
...
@@ -216,11 +217,8 @@ void elementwise_cuda_kernel_0(size_t n, float *x, float* y, float* z) {
...
@@ -216,11 +217,8 @@ void elementwise_cuda_kernel_0(size_t n, float *x, float* y, float* z) {
}
}
};
};
std
::
vector
<
std
::
string
>
inputs_data_type
(
input_names
.
size
(),
"float"
);
TestMain
(
input_names
,
input_shapes
,
output_names
,
0
,
std
::
vector
<
std
::
string
>
outs_data_type
(
output_names
.
size
(),
"float"
);
"elementwise_cuda_kernel_0"
,
kernel
,
elementwise_cpu_kernel_0
);
TestMain
(
input_names
,
input_shapes
,
output_names
,
0
,
inputs_data_type
,
outs_data_type
,
"elementwise_cuda_kernel_0"
,
kernel
,
elementwise_cpu_kernel_0
);
}
}
}
// namespace operators
}
// namespace operators
...
...
python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
浏览文件 @
1be6bf45
...
@@ -77,12 +77,13 @@ class FusionGroupPassTest(PassTest):
...
@@ -77,12 +77,13 @@ class FusionGroupPassTest(PassTest):
self
.
check_output_with_place
(
fluid
.
CUDAPlace
(
0
))
self
.
check_output_with_place
(
fluid
.
CUDAPlace
(
0
))
class
FusionGroupPass
Test1
(
FusionGroupPassTest
):
class
FusionGroupPass
ComplicatedTest
(
FusionGroupPassTest
):
def
build_program
(
self
,
dtype
):
def
build_program
(
self
,
dtype
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
self
.
feed_vars
=
self
.
_prepare_feed_vars
([
32
,
128
],
dtype
,
5
)
self
.
feed_vars
=
self
.
_prepare_feed_vars
([
32
,
64
],
dtype
,
5
)
tmp_0
=
layers
.
assign
(
self
.
feed_vars
[
0
])
one
=
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
dtype
,
value
=
1.0
)
tmp_0
=
one
*
self
.
feed_vars
[
0
]
# subgraph with 9 op nodes
# subgraph with 9 op nodes
tmp_1
=
tmp_0
*
layers
.
sigmoid
(
self
.
feed_vars
[
1
])
+
layers
.
sigmoid
(
tmp_1
=
tmp_0
*
layers
.
sigmoid
(
self
.
feed_vars
[
1
])
+
layers
.
sigmoid
(
self
.
feed_vars
[
2
])
*
layers
.
tanh
(
self
.
feed_vars
[
3
])
self
.
feed_vars
[
2
])
*
layers
.
tanh
(
self
.
feed_vars
[
3
])
...
@@ -94,7 +95,7 @@ class FusionGroupPassTest1(FusionGroupPassTest):
...
@@ -94,7 +95,7 @@ class FusionGroupPassTest1(FusionGroupPassTest):
self
.
fetch_list
=
[
tmp_2
,
self
.
grad
(
tmp_0
)]
self
.
fetch_list
=
[
tmp_2
,
self
.
grad
(
tmp_0
)]
class
FusionGroupPass
Test2
(
FusionGroupPassTest
):
class
FusionGroupPass
InplaceTest
(
FusionGroupPassTest
):
def
build_program
(
self
,
dtype
):
def
build_program
(
self
,
dtype
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
self
.
feed_vars
=
self
.
_prepare_feed_vars
([
32
,
128
],
dtype
,
3
)
self
.
feed_vars
=
self
.
_prepare_feed_vars
([
32
,
128
],
dtype
,
3
)
...
@@ -103,15 +104,13 @@ class FusionGroupPassTest2(FusionGroupPassTest):
...
@@ -103,15 +104,13 @@ class FusionGroupPassTest2(FusionGroupPassTest):
name
=
"data3"
,
shape
=
[
128
,
32
],
dtype
=
dtype
))
name
=
"data3"
,
shape
=
[
128
,
32
],
dtype
=
dtype
))
# subgraph with 3 op node
# subgraph with 3 op node
tmp_0
=
self
.
feed_vars
[
0
]
+
self
.
feed_vars
[
1
]
tmp_0
=
self
.
feed_vars
[
0
]
-
self
.
feed_vars
[
1
]
tmp_1
=
layers
.
relu
(
self
.
feed_vars
[
2
]
*
tmp_0
)
tmp_1
=
tmp_0
*
self
.
feed_vars
[
2
]
# subgraph with 2 op nodes
tmp_2
=
layers
.
assign
(
tmp_1
,
output
=
tmp_0
)
tmp_2
=
layers
.
relu
(
layers
.
sigmoid
(
self
.
feed_vars
[
3
]))
tmp_3
=
layers
.
mul
(
tmp_2
,
self
.
feed_vars
[
3
])
tmp_3
=
layers
.
mul
(
tmp_1
,
tmp_2
)
self
.
append_gradients
(
tmp_3
)
self
.
num_fused_ops
=
1
self
.
num_fused_ops
=
2
self
.
fetch_list
=
[
tmp_3
]
self
.
fetch_list
=
[
tmp_3
,
self
.
grad
(
tmp_1
)]
class
FusionGroupPassTestFP64
(
FusionGroupPassTest
):
class
FusionGroupPassTestFP64
(
FusionGroupPassTest
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录