PaddlePaddle / Paddle — commit f0d193a2 (unverified)
Authored by wangchaochaohu on Mar 12, 2020; committed via GitHub on Mar 12, 2020.
Cast fusion for fusion group (#22876)
* Add support for expression type conversion and add cast op support in fusion group
Parent: 29a7a52d
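For context, here is a hand-written sketch (not part of the commit) of the kind of CUDA kernel the fusion_group code generator emits once per-variable dtypes are tracked: every pointer parameter carries its own element type, float16 operands are widened with __half2float before the arithmetic, and results are narrowed again on assignment. The names arg0/tmp0 follow the ArgName/TmpName helpers in the diff below; the signature and body are illustrative only.

// Illustrative only: a fused multiply whose inputs are float16 and whose
// output is float, roughly matching the load/compute/store emission below.
extern "C" __global__ void fused_elementwise_0(int N, float16* arg0,
                                               float16* arg1, float* arg2) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N;
       idx += gridDim.x * blockDim.x) {
    float16 tmp0 = arg0[idx];                      // load, typed per dtypes[id]
    float16 tmp1 = arg1[idx];
    float half2fp32_tmp0 = __half2float(tmp0);     // fp16 computed as fp32
    float half2fp32_tmp1 = __half2float(tmp1);
    float tmp2 = half2fp32_tmp0 * half2fp32_tmp1;  // compute body
    arg2[idx] = tmp2;                              // store
  }
}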
Showing 15 changed files with 337 additions and 135 deletions (+337 -135)
paddle/fluid/framework/ir/fusion_group/code_generator.cc              +79 -36
paddle/fluid/framework/ir/fusion_group/code_generator.h                +5  -3
paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc       +45  -6
paddle/fluid/framework/ir/fusion_group/code_generator_helper.h        +18  -4
paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc       +32 -29
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc  +47  -1
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h    +6  -1
paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc            +8  -0
paddle/fluid/framework/ir/fusion_group/operation.cc                   +10  -1
paddle/fluid/framework/ir/fusion_group/subgraph.h                      +1 -34
paddle/fluid/operators/fused/fusion_group_op.cc                       +14  -1
paddle/fluid/operators/fused/fusion_group_op.h                        +33  -6
paddle/fluid/operators/fused/fusion_group_op_test.cc                  +14  -5
python/paddle/fluid/tests/unittests/ir/pass_test.py                    +2  -2
python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py   +23  -6
paddle/fluid/framework/ir/fusion_group/code_generator.cc (view file @ f0d193a2)

@@ -24,6 +24,21 @@ namespace framework {
 namespace ir {
 namespace fusion_group {
 
+std::string ExtractDataType(const std::vector<Node*> nodes) {
+  std::string dtype_str = "float";
+  auto data_type = nodes.back()->Var()->GetDataType();
+  if (data_type == proto::VarType::FP32) {
+    dtype_str = "float";
+  } else if (data_type == proto::VarType::FP64) {
+    dtype_str = "double";
+  } else if (data_type == proto::VarType::FP16) {
+    dtype_str = "float16";
+  }
+  return dtype_str;
+}
+
 CodeGenerator::CodeGenerator() {
   // Only support elementwise operations now.
   code_templates_.resize(1);

@@ -34,8 +49,7 @@ CodeGenerator::CodeGenerator() {
 std::string CodeGenerator::Generate(SubGraph* subgraph) {
   std::vector<OperationExpression> expressions = ConvertToExpressions(subgraph);
-  return Generate(subgraph->GetFuncName(), subgraph->GetDataType(),
-                  expressions);
+  return Generate(subgraph->GetFuncName(), expressions);
 }
 
 static bool HasInput(Node* n, std::string name) {

@@ -95,8 +109,11 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
               "Output(%s) of operation %s is not set.", name, op->Type()));
       output_ids.push_back(var_ids[op->Output(name)[0]]);
     }
-    expressions.push_back(
-        OperationExpression(node->Name(), input_ids, output_ids));
+    std::string lhs_type = ExtractDataType(node->outputs);
+    std::string rhs_type = ExtractDataType(node->inputs);
+    expressions.emplace_back(OperationExpression(
+        node->Name(), input_ids, output_ids, rhs_type, lhs_type));
    }
  }
  return expressions;

@@ -105,25 +122,32 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
 // In order to get the right result of expression, we need to calculate and
 // store the expression as suffix Expressions using vector.
-std::string CodeGenerator::Generate(
-    std::string func_name, std::string dtype,
-    const std::vector<OperationExpression>& expressions) {
+std::string CodeGenerator::Generate(
+    std::string func_name,
+    const std::vector<OperationExpression>& expressions) {
   // TODO(liuyiqun): Check whether all expressions are elementwise operations.
   std::set<int> input_ids = DistilInputIds(expressions);
   std::set<int> output_ids = DistilOutputIds(expressions);
+  std::unordered_map<int, std::string> dtypes = DistilDtypes(expressions);
   TemplateVariable template_var;
   template_var.Add("func_name", func_name);
-  template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtype));
+  template_var.Add("parameters",
+                   EmitParameters(input_ids, output_ids, dtypes));
   template_var.Add("compute_body",
-                   EmitComputeBody(expressions, input_ids, output_ids, dtype));
+                   EmitComputeBody(expressions, input_ids, output_ids, dtypes));
-  std::string predefined_cuda_functions;
-  if (dtype == "float") {
-    predefined_cuda_functions = predefined_cuda_functions_fp32;
-  } else if (dtype == "double") {
-    predefined_cuda_functions = predefined_cuda_functions_fp64;
-  } else if (dtype == "float16") {
-    predefined_cuda_functions = predefined_cuda_functions_fp16;
-  }
+  std::set<std::string> all_dtype;
+  for (const auto& type : dtypes) {
+    all_dtype.insert(type.second);
+  }
+  std::string predefined_cuda_functions = "";
+  if (all_dtype.find("float") != all_dtype.end() &&
+      all_dtype.find("float16") == all_dtype.end()) {
+    predefined_cuda_functions += predefined_cuda_functions_fp32;
+  }
+  if (all_dtype.find("double") != all_dtype.end()) {
+    predefined_cuda_functions += predefined_cuda_functions_fp64;
+  }
+  if (all_dtype.find("float16") != all_dtype.end()) {
+    predefined_cuda_functions += predefined_cuda_functions_fp16;
+  }
   return predefined_cuda_functions + code_templates_[0].Format(template_var);
 }

@@ -154,10 +178,40 @@ std::set<int> CodeGenerator::DistilOutputIds(
   return output_ids;
 }
 
+std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
+    const std::vector<OperationExpression>& expressions) {
+  std::unordered_map<int, std::string> dtypes;
+  for (const auto& expression : expressions) {
+    for (auto id : expression.GetInputIds()) {
+      auto dtype = expression.GetRHSType();
+      if (dtypes.find(id) == dtypes.end()) {
+        dtypes[id] = dtype;
+      } else {
+        PADDLE_ENFORCE_EQ(
+            dtypes[id], dtype,
+            platform::errors::PreconditionNotMet(
+                "In fusion group, Same Node id must have same date type"));
+      }
+    }
+    for (auto id : expression.GetOutputIds()) {
+      auto dtype = expression.GetLHSType();
+      if (dtypes.find(id) == dtypes.end()) {
+        dtypes[id] = dtype;
+      } else {
+        PADDLE_ENFORCE_EQ(
+            dtypes[id], dtype,
+            platform::errors::PreconditionNotMet(
+                "In fusion group, Same Node id must have same date type"));
+      }
+    }
+  }
+  return dtypes;
+}
+
 // we get the parameter list code for the expression information
-std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
-                                          const std::set<int>& output_ids,
-                                          std::string dtype) {
+std::string CodeGenerator::EmitParameters(
+    const std::set<int>& input_ids, const std::set<int>& output_ids,
+    std::unordered_map<int, std::string> dtypes) {
   std::stringstream ret;
   ret << "int N, ";

@@ -165,13 +219,13 @@ std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
   // from the input list.
   for (auto id : input_ids) {
     if (output_ids.find(id) == output_ids.end()) {
-      ret << dtype << "* " << ArgName(id) << ", ";
+      ret << dtypes[id] << "* " << ArgName(id) << ", ";
     }
   }
 
   size_t index = 0;
   for (auto id : output_ids) {
-    ret << dtype << "* " << ArgName(id);
+    ret << dtypes[id] << "* " << ArgName(id);
     if (index != output_ids.size() - 1) {
       ret << ", ";
     }

@@ -184,13 +238,12 @@ std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
 std::string CodeGenerator::EmitComputeBody(
     const std::vector<OperationExpression>& expressions,
-    const std::set<int>& input_ids, const std::set<int>& output_ids,
-    std::string dtype) {
+    const std::set<int>& input_ids, const std::set<int>& output_ids,
+    std::unordered_map<int, std::string> dtypes) {
   std::ostringstream compute;
   std::unordered_set<int> used;
-  std::string compute_dtype = (dtype == "float16") ? "float" : dtype;
   for (size_t i = 0; i < expressions.size(); i++) {
     VLOG(3) << DebugString(expressions[i]);
-    compute << expressions[i].GetExpression(compute_dtype, &used);
+    compute << expressions[i].GetExpression(&used);
   }
 
   // Load input to temporal variables.

@@ -198,23 +251,13 @@ std::string CodeGenerator::EmitComputeBody(
   for (auto id : input_ids) {
     if (output_ids.find(id) == output_ids.end() &&
         used.find(id) != used.end()) {
-      if (dtype == "float16") {
-        load << "float " << TmpName(id) << " = __half2float("
-             << ArgName(id) << "[idx]);";
-      } else {
-        load << dtype << " " << TmpName(id) << " = " << ArgName(id)
-             << "[idx];";
-      }
+      load << dtypes[id] << " " << TmpName(id) << " = " << VarName(id) << ";";
     }
   }
   // Store temporal variables to memory.
   std::ostringstream store;
   for (auto id : output_ids) {
-    if (dtype == "float16") {
-      store << ArgName(id) << "[idx] = __float2half(" << TmpName(id) << ");";
-    } else {
-      store << ArgName(id) << "[idx] = " << TmpName(id) << ";";
-    }
+    store << VarName(id) << " = " << TmpName(id) << ";";
   }
 
   return load.str() + compute.str() + store.str();
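The new DistilDtypes pass above is the heart of the change: it assigns one dtype per encoded variable id and insists the assignment is consistent across all expressions. A self-contained sketch of that logic (hypothetical Expr struct and plain assert standing in for OperationExpression and PADDLE_ENFORCE_EQ):

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

struct Expr {  // hypothetical stand-in for OperationExpression
  std::vector<int> in_ids, out_ids;
  std::string rhs_type, lhs_type;
};

std::unordered_map<int, std::string> DistilDtypes(const std::vector<Expr>& es) {
  std::unordered_map<int, std::string> dtypes;
  auto record = [&](int id, const std::string& t) {
    auto it = dtypes.find(id);
    if (it == dtypes.end()) {
      dtypes[id] = t;  // first sighting fixes the dtype for this id
    } else {
      assert(it->second == t && "same id must have same dtype");
    }
  };
  for (const auto& e : es) {
    for (int id : e.in_ids) record(id, e.rhs_type);
    for (int id : e.out_ids) record(id, e.lhs_type);
  }
  return dtypes;
}

int main() {
  // A cast node: float16 input (id 0) -> float output (id 1).
  std::vector<Expr> es = {{{0}, {1}, "float16", "float"}};
  auto m = DistilDtypes(es);
  assert(m[0] == "float16" && m[1] == "float");
}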
paddle/fluid/framework/ir/fusion_group/code_generator.h (view file @ f0d193a2)

@@ -30,7 +30,7 @@ class CodeGenerator {
  public:
   CodeGenerator();
 
-  std::string Generate(std::string func_name, std::string dtype,
+  std::string Generate(std::string func_name,
                        const std::vector<OperationExpression>& expressions);
 
   std::string Generate(SubGraph* subgraph);

@@ -42,16 +42,18 @@ class CodeGenerator {
       const std::vector<OperationExpression>& expressions);
   std::set<int> DistilOutputIds(
       const std::vector<OperationExpression>& expressions);
+  std::unordered_map<int, std::string> DistilDtypes(
+      const std::vector<OperationExpression>& expressions);
 
   // we get the parameter list code for the expression information
   std::string EmitParameters(const std::set<int>& input_ids,
-                             const std::set<int>& output_ids,
-                             std::string dtype);
+                             const std::set<int>& output_ids,
+                             std::unordered_map<int, std::string> dtypes);
 
   std::string EmitComputeBody(
       const std::vector<OperationExpression>& expressions,
-      const std::set<int>& input_ids, const std::set<int>& output_ids,
-      std::string dtype);
+      const std::set<int>& input_ids, const std::set<int>& output_ids,
+      std::unordered_map<int, std::string> dtypes);
 
   // Encode all var nodes in the subgraph with an unique number.
   std::unordered_map<std::string, int> EncodeVarNodes(SubGraph* subgraph);
paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc (view file @ f0d193a2)

@@ -50,10 +50,26 @@ static std::string ExpandMultivariateTemplate(const std::string rhs,
   return sum_rhs;
 }
 
+// In order to avoid multiple __half2float function calls, we do this
+// optimization
+static std::string OptimzeFP16RHS(std::unordered_set<int>* used,
+                                  const int index,
+                                  const std::vector<int>& input_ids) {
+  std::stringstream ret;
+  if (used->find(input_ids[index]) == used->end()) {
+    ret << "float half2fp32_" + TmpName(input_ids[index]) +
+               " = __half2float(" + TmpName(input_ids[index]) + ");";
+  }
+  return ret.str();
+}
+
 std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
+                                        std::string* half2fp32_statement,
                                         size_t exprs_index) const {
   auto rhs = OperationMap::Instance().Get(op_type_).exprs[exprs_index];
   auto num_operands = OperationMap::Instance().Get(op_type_).num_operands;
 
   if (num_operands == -1) {
     size_t input_size = input_ids_.size();
     rhs = ExpandMultivariateTemplate(rhs, input_size);

@@ -78,7 +94,16 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
           platform::errors::InvalidArgument(
               "Expected %d-th input id > 0 for operation < %s >. Received %d.",
               index, op_type_, input_ids_[index]));
-      rhs.replace(pos, length + 3, TmpName(input_ids_[index]));
+      // TODO(wangchaochaohu): Here fp16 convert to float to do comupte, we need
+      // to add general fp16 compute later.
+      std::string var_name;
+      if (rhs_type_ == "float16") {
+        half2fp32_statement->append(OptimzeFP16RHS(used, index, input_ids_));
+        var_name = "half2fp32_" + TmpName(input_ids_[index]);
+      } else {
+        var_name = TmpName(input_ids_[index]);
+      }
+      rhs.replace(pos, length + 3, var_name);
       used->insert(input_ids_[index]);
     }
   }

@@ -87,7 +112,7 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
 std::string OperationExpression::GetLHS(size_t i) const {
   std::stringstream ret;
-  ret << TmpName(output_ids_[i]);
+  ret << lhs_type_ << " " << TmpName(output_ids_[i]);
   return ret.str();
 }

@@ -98,15 +123,29 @@ bool OperationExpression::IsSupport() const {
 // we Traverse the graph and get the group , all input id and output id is
 // unique for the node which belong the group
-std::string OperationExpression::GetExpression(std::string dtype,
-                                               std::unordered_set<int>* used) const {
+std::string OperationExpression::GetExpression(
+    std::unordered_set<int>* used) const {
+  std::string half2fp32_statement;
   std::stringstream ret;
   if (IsSupport()) {
     for (size_t i = 0; i < output_ids_.size(); ++i) {
-      ret << dtype << " " << GetLHS(i) << " = " << GetRHS(used, i) << ";";
+      std::string cast_str = "";
+      if ((lhs_type_ == rhs_type_ && rhs_type_ != "float16") ||
+          (lhs_type_ != rhs_type_ && rhs_type_ == "float16")) {
+        ret << GetLHS(i) << " = " << GetRHS(used, &half2fp32_statement, i)
+            << ";";
+      } else {
+        if ((lhs_type_ == rhs_type_ && rhs_type_ == "float16") ||
+            lhs_type_ == "float16") {
+          cast_str = "__float2half";
+        } else {
+          cast_str = "static_cast<" + lhs_type_ + ">";
+        }
+        ret << GetLHS(i) << " = " << cast_str << "("
+            << GetRHS(used, &half2fp32_statement, i) << ");";
+      }
     }
   }
-  return ret.str();
+  return half2fp32_statement + ret.str();
 }
 
 }  // namespace fusion_group
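The cast selection in GetExpression above reduces to a small pure function of the lhs/rhs type strings: no wrapper when the rhs already lands in the lhs type (including the case where an fp16 rhs was widened to float by GetRHS), __float2half when the destination is float16, static_cast<lhs> otherwise. A minimal sketch under those assumptions (hypothetical CastFor helper, not part of the commit):

#include <cassert>
#include <string>

// Mirrors the branching in OperationExpression::GetExpression; returns the
// cast wrapper to emit, "" meaning no cast is needed.
std::string CastFor(const std::string& lhs, const std::string& rhs) {
  if ((lhs == rhs && rhs != "float16") || (lhs != rhs && rhs == "float16")) {
    return "";  // rhs already has the lhs type (fp16 was widened to float)
  }
  if ((lhs == rhs && rhs == "float16") || lhs == "float16") {
    return "__float2half";  // compute happened in float; narrow on store
  }
  return "static_cast<" + lhs + ">";
}

int main() {
  assert(CastFor("float", "float") == "");
  assert(CastFor("float", "float16") == "");  // already widened by GetRHS
  assert(CastFor("float16", "float16") == "__float2half");
  assert(CastFor("double", "float") == "static_cast<double>");
}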
paddle/fluid/framework/ir/fusion_group/code_generator_helper.h (view file @ f0d193a2)

@@ -30,29 +30,41 @@ namespace fusion_group {
 static inline std::string ArgName(int index) {
   return "arg" + std::to_string(index);
 }
 
 static inline std::string TmpName(int index) {
   return "tmp" + std::to_string(index);
 }
 
+static inline std::string VarName(int index) {
+  return "arg" + std::to_string(index) + "[idx]";
+}
+
 class OperationExpression {
  public:
   explicit OperationExpression(std::string op_type, std::vector<int> input_ids,
-                               std::vector<int> output_ids)
-      : op_type_(op_type), input_ids_(input_ids), output_ids_(output_ids) {}
+                               std::vector<int> output_ids,
+                               std::string rhs_type, std::string lhs_type)
+      : op_type_(op_type),
+        input_ids_(input_ids),
+        output_ids_(output_ids),
+        rhs_type_(rhs_type),
+        lhs_type_(lhs_type) {}
 
   std::string GetOpType() const { return op_type_; }
   std::vector<int> GetInputIds() const { return input_ids_; }
   std::vector<int> GetOutputIds() const { return output_ids_; }
+  std::string GetRHSType() const { return rhs_type_; }
+  std::string GetLHSType() const { return lhs_type_; }
 
   // Check whether this operation type is supported in OperationMap.
   bool IsSupport() const;
 
-  std::string GetExpression(std::string dtype,
-                            std::unordered_set<int>* used) const;
+  std::string GetExpression(std::unordered_set<int>* used) const;
 
  private:
   // TODO(wangchao): make offset more flexible we add stride and basic offset
   std::string GetRHS(std::unordered_set<int>* used,
+                     std::string* half2fp32_statement,
                      size_t exprs_index = 0) const;
   std::string GetLHS(size_t i = 0) const;

@@ -60,6 +72,8 @@ class OperationExpression {
   std::string op_type_;
   std::vector<int> input_ids_;
   std::vector<int> output_ids_;
+  std::string rhs_type_;
+  std::string lhs_type_;
 };
 
 class TemplateVariable {
paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc (view file @ f0d193a2)

@@ -288,7 +288,7 @@ void TestMain(std::string func_name,
               std::string dtype) {
   fusion_group::OperationMap::Init();
   fusion_group::CodeGenerator code_generator;
-  std::string code_str = code_generator.Generate(func_name, dtype, expressions);
+  std::string code_str = code_generator.Generate(func_name, expressions);
   VLOG(3) << code_str;
 
   LOG(INFO) << "dtype: " << dtype;

@@ -297,7 +297,7 @@ void TestMain(std::string func_name,
 }
 
 void TestMain(fusion_group::SubGraph* subgraph, std::vector<int> input_ids,
-              std::vector<int> output_ids) {
+              std::vector<int> output_ids, std::string dtype) {
   fusion_group::OperationMap::Init();
   fusion_group::CodeGenerator code_generator;
   std::string code_str = code_generator.Generate(subgraph);

@@ -307,26 +307,28 @@ void TestMain(fusion_group::SubGraph* subgraph, std::vector<int> input_ids,
   std::vector<fusion_group::OperationExpression> expressions =
       code_generator.ConvertToExpressions(subgraph);
 
-  LOG(INFO) << "dtype: " << subgraph->GetDataType();
   TestElementwiseMain(subgraph->GetFuncName(), code_str, expressions, input_ids,
-                      output_ids, subgraph->GetDataType());
+                      output_ids, dtype);
 }
 
 TEST(code_generator, elementwise) {
+  for (std::string dtype : {"float", "float16"}) {
     // t2 = t0 * t1
     // t4 = t2 + t3
     // t6 = t4 - t5
    // t7 = relu(t6)
    // t8 = sigmoid(t7)
-  fusion_group::OperationExpression exp1("elementwise_mul", {0, 1}, {2});
-  fusion_group::OperationExpression exp2("elementwise_add", {2, 3}, {4});
-  fusion_group::OperationExpression exp3("elementwise_sub", {4, 5}, {6});
-  fusion_group::OperationExpression exp4("relu", {6}, {7});
-  fusion_group::OperationExpression exp5("sigmoid", {7}, {8});
+    fusion_group::OperationExpression exp1("elementwise_mul", {0, 1}, {2},
+                                           dtype, dtype);
+    fusion_group::OperationExpression exp2("elementwise_add", {2, 3}, {4},
+                                           dtype, dtype);
+    fusion_group::OperationExpression exp3("elementwise_sub", {4, 5}, {6},
+                                           dtype, dtype);
+    fusion_group::OperationExpression exp4("relu", {6}, {7}, dtype, dtype);
+    fusion_group::OperationExpression exp5("sigmoid", {7}, {8}, dtype, dtype);
     std::vector<fusion_group::OperationExpression> expressions = {
         exp1, exp2, exp3, exp4, exp5};
-  for (std::string dtype : {"float", "float16"}) {
     // Expressions:
     // Op(elementwise_mul), inputs:{0,1}, outputs:{2}
    // Op(elementwise_add), inputs:{2,3}, outputs:{4}

@@ -340,17 +342,18 @@ TEST(code_generator, elementwise) {
 }
 
 TEST(code_generator, elementwise_grad) {
+  for (std::string dtype : {"float", "float16"}) {
     // The var order: t0, t1, t2, t3, t0', t1', t2', t3'
     // t2 = t0 * t1
     // t3 = relu(t2)
    // t2' = relu_grad(t2, t3, t3')
    // t0', t1' = elementwise_mul_grad(t0, t1, t2, t2')
-  fusion_group::OperationExpression exp1("relu_grad", {-1, 3, 7}, {6});
-  fusion_group::OperationExpression exp2("elementwise_mul_grad", {0, 1, 2, 6},
-                                         {4, 5});
+    fusion_group::OperationExpression exp1("relu_grad", {-1, 3, 7}, {6}, dtype,
+                                           dtype);
+    fusion_group::OperationExpression exp2("elementwise_mul_grad", {0, 1, 2, 6},
+                                           {4, 5}, dtype, dtype);
     std::vector<fusion_group::OperationExpression> expressions = {exp1, exp2};
-  for (std::string dtype : {"float", "float16"}) {
     // Expressions:
     // Op(relu_grad), inputs:{2,3,7}, outputs:{6}
    // Op(elementwise_mul_grad), inputs:{0,1,2,6}, outputs:{4,5}

@@ -474,7 +477,7 @@ TEST(code_generator, subgraph) {
     // Op(elementwise_add), inputs:{7,6}, outputs:{8}
     std::vector<int> input_ids = {0, 1, 2, 3};
     std::vector<int> output_ids = {4, 5, 6, 7, 8};
-    TestMain(&subgraph, input_ids, output_ids);
+    TestMain(&subgraph, input_ids, output_ids, dtype);
   }
 }

@@ -493,7 +496,7 @@ TEST(code_generator, subgraph_grad) {
     // Op(tanh_grad), inputs:{9,4,13}, outputs:{14}
     std::vector<int> input_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
     std::vector<int> output_ids = {10, 11, 12, 13, 14, 15, 16, 17};
-    TestMain(&subgraph, input_ids, output_ids);
+    TestMain(&subgraph, input_ids, output_ids, dtype);
   }
 }
 #endif
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc (view file @ f0d193a2)

@@ -60,6 +60,50 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
   return l.size() != 0U && r.size() != 0U && l == r;
 }
 
+bool GroupDetector::IsFusionGroupOp(const Node* n) {
+  if (!(n && n->IsOp() && n->Op())) return false;
+  bool is_first = true;
+  proto::VarType::Type i_data_type = proto::VarType::FP32;
+  proto::VarType::Type o_data_type = proto::VarType::FP32;
+  for (auto* i_node : n->inputs) {
+    if (!i_node->Var()) return false;
+    if (i_node->Var()->GetType() != proto::VarType::LOD_TENSOR) {
+      return false;
+    }
+    if (is_first) {
+      i_data_type = i_node->Var()->GetDataType();
+      is_first = false;
+    } else {
+      if (i_data_type != i_node->Var()->GetDataType()) return false;
+    }
+  }
+
+  is_first = true;
+  for (auto* o_node : n->outputs) {
+    if (!o_node->Var()) return false;
+    if (o_node->Var()->GetType() != proto::VarType::LOD_TENSOR) {
+      return false;
+    }
+    if (is_first) {
+      o_data_type = o_node->Var()->GetDataType();
+      is_first = false;
+    } else {
+      if (o_data_type != o_node->Var()->GetDataType()) return false;
+    }
+  }
+
+  if (!(i_data_type == proto::VarType::FP32 ||
+        i_data_type == proto::VarType::FP64 ||
+        i_data_type == proto::VarType::FP16) ||
+      !(o_data_type == proto::VarType::FP32 ||
+        o_data_type == proto::VarType::FP64 ||
+        o_data_type == proto::VarType::FP16))
+    return false;
+
+  return true;
+}
+
 bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
   if (IsSpecifiedOp(GetElementwiseOpTypes(), n)) {
     std::vector<int64_t> shape_0;

@@ -85,7 +129,9 @@ bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
 std::vector<std::vector<Node*>> ElementwiseGroupDetector::operator()(
     Graph* graph) {
-  auto teller = [&](const Node* n) -> bool { return IsElementwiseOp(n); };
+  auto teller = [&](const Node* n) -> bool {
+    return IsFusionGroupOp(n) && IsElementwiseOp(n);
+  };
 
   return SubgraphDetector(graph, teller)();
 }
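IsFusionGroupOp encodes a simple admission rule: all of a node's inputs must share one dtype, all outputs must share one dtype, and both must be fp16/fp32/fp64. A standalone sketch of that rule (hypothetical DType enum, not Paddle's proto::VarType):

#include <vector>

enum class DType { FP16, FP32, FP64, INT32 };

// Sketch of the dtype rule in IsFusionGroupOp: every input shares one dtype,
// every output shares one dtype, and both are floating point.
static bool DtypesFusable(const std::vector<DType>& ins,
                          const std::vector<DType>& outs) {
  auto uniform_fp = [](const std::vector<DType>& v) {
    if (v.empty()) return true;
    for (DType d : v) {
      if (d != v.front()) return false;
    }
    return v.front() == DType::FP16 || v.front() == DType::FP32 ||
           v.front() == DType::FP64;
  };
  return uniform_fp(ins) && uniform_fp(outs);
}

int main() {
  bool ok1 = DtypesFusable({DType::FP16, DType::FP16}, {DType::FP32});   // true
  bool ok2 = DtypesFusable({DType::FP32, DType::INT32}, {DType::FP32});  // false
  return (ok1 && !ok2) ? 0 : 1;
}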
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h (view file @ f0d193a2)

@@ -23,7 +23,12 @@ namespace framework {
 namespace ir {
 namespace fusion_group {
 
-class ElementwiseGroupDetector {
+class GroupDetector {
+ protected:
+  bool IsFusionGroupOp(const Node* n);
+};
+
+class ElementwiseGroupDetector : GroupDetector {
  public:
   std::vector<std::vector<Node*>> operator()(Graph* graph);
paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc (view file @ f0d193a2)

@@ -110,18 +110,25 @@ void FusionGroupPass::InsertFusionGroupOp(
   op_desc.SetType("fusion_group");
 
   std::vector<std::string> input_names;
+  std::vector<std::string> inputs_data_types;
   for (auto* n : input_vars_of_subgraph) {
     input_names.push_back(n->Name());
+    inputs_data_types.push_back(DataTypeToString(n->Var()->GetDataType()));
     external_nodes.insert(n);
   }
   op_desc.SetInput("Inputs", input_names);
 
   std::vector<std::string> output_names;
+  std::vector<std::string> outs_data_types;
   for (auto* n : output_vars_of_subgraph) {
     output_names.push_back(n->Name());
+    outs_data_types.push_back(DataTypeToString(n->Var()->GetDataType()));
     external_nodes.insert(n);
   }
   op_desc.SetOutput("Outs", output_names);
+  op_desc.SetAttr("inputs_data_type", inputs_data_types);
+  op_desc.SetAttr("outs_data_type", outs_data_types);
   op_desc.SetAttr("type", subgraph->GetType());
   op_desc.SetAttr("func_name", subgraph->GetFuncName());
   op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),

@@ -131,6 +138,7 @@ void FusionGroupPass::InsertFusionGroupOp(
   for (auto* in : input_vars_of_subgraph) {
     IR_NODE_LINK_TO(in, fusion_group_node);
   }
+
   for (auto* out : output_vars_of_subgraph) {
     IR_NODE_LINK_TO(fusion_group_node, out);
   }
paddle/fluid/framework/ir/fusion_group/operation.cc (view file @ f0d193a2)

@@ -102,6 +102,13 @@ void OperationMap::InsertUnaryElementwiseOperations() {
   // dx = dout * (1 - out * out)
   insert_handler("tanh", "2.0 / (1.0 + real_exp(-2.0 * ${0})) - 1.0",
                  {"${2} * (1.0 - ${1} * ${1})"});
+
+  // cast:
+  // out = static_cast<T>(d)
+  // dx = static_cast<T>(d_out)
+  // TODO(wangchaochaohu): This is not the compelete definition of
+  // cast Op, We need refine it later.
+  insert_handler("cast", "${0}", {"${0}"});
 }
 
 void OperationMap::InsertBinaryElementwiseOperations() {

@@ -158,10 +165,12 @@ void OperationMap::InsertMultivariateElementwiseOperations() {
                             std::vector<std::string> grad_exprs) {
     int type = 0;
     int num_oprands = -1;
+    // here ... represent the number of input is changed
     Insert(type, num_oprands, op_type, expr, grad_exprs, {"X"}, {"Out"});
   };
 
+  // here [] represent the number of input is positive(>=0).
+  // if input list size of Sum Op is 3, It will expand as
+  // ${0} + ${1} + ${2}
   insert_handler("sum", "${0}[ + ${?}]", {});
 }
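The `${i}` placeholders in these templates are substituted with temporary-variable names when the kernel body is emitted; for the new cast op the template is just `${0}`, so the conversion itself comes entirely from the lhs/rhs dtype bookkeeping in the code generator. A simplified sketch of the substitution step (hypothetical helper, single-digit indices only; the real logic lives in OperationExpression::GetRHS):

#include <cassert>
#include <string>
#include <vector>

// Replace each "${k}" with names[k]; a simplified stand-in for the
// placeholder substitution done in OperationExpression::GetRHS.
std::string Substitute(std::string expr,
                       const std::vector<std::string>& names) {
  for (size_t pos = expr.find("${"); pos != std::string::npos;
       pos = expr.find("${", pos)) {
    size_t k = expr[pos + 2] - '0';  // single-digit index
    expr.replace(pos, 4, names[k]);  // "${k}" is 4 characters
  }
  return expr;
}

int main() {
  // cast: template "${0}" becomes the input temporary; the cast wrapper
  // is added separately from the lhs/rhs dtypes.
  assert(Substitute("${0}", {"tmp0"}) == "tmp0");
  // sum with 3 inputs, after ExpandMultivariateTemplate:
  assert(Substitute("${0} + ${1} + ${2}", {"tmp0", "tmp1", "tmp2"}) ==
         "tmp0 + tmp1 + tmp2");
}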
paddle/fluid/framework/ir/fusion_group/subgraph.h (view file @ f0d193a2)

@@ -49,7 +49,6 @@ class SubGraph {
        }
      }
    }
-    ExtractDataType();
  }
 
  bool IsValid(int min_subgraph_size) {

@@ -61,11 +60,10 @@ class SubGraph {
       return false;
     }
-    return ExtractDataType();
+    return true;
   }
 
   int GetType() const { return type_; }
-  std::string GetDataType() const { return data_type_; }
 
   void SetFuncName(std::string func_name) { func_name_ = func_name; }
   std::string GetFuncName() const { return func_name_; }

@@ -162,37 +160,6 @@ class SubGraph {
   }
 
  private:
-  bool ExtractDataType() {
-    bool is_first = true;
-    proto::VarType::Type data_type = proto::VarType::FP32;
-    for (auto* n : nodes_set_) {
-      if (n && n->IsVar() && n->Var()) {
-        if (n->Var()->GetType() != proto::VarType::LOD_TENSOR) {
-          // All var node in a subgraph should hold a LoDTensor.
-          return false;
-        }
-        if (is_first) {
-          data_type = n->Var()->GetDataType();
-          is_first = false;
-        } else if (n->Var()->GetDataType() != data_type) {
-          // DataType of VarDesc in a subgraph is not the same.
-          return false;
-        }
-      }
-    }
-    if (data_type == proto::VarType::FP32) {
-      data_type_ = "float";
-    } else if (data_type == proto::VarType::FP64) {
-      data_type_ = "double";
-    } else if (data_type == proto::VarType::FP16) {
-      data_type_ = "float16";
-    } else {
-      VLOG(2) << "Only support fp32, fp64 and fp16 in fusion_group.";
-      return false;
-    }
-    return true;
-  }
-
   void TopologicalSort() {
     if (!is_sorted_) {
       std::unordered_map<Node*, std::vector<Node*>> inputs_map;
paddle/fluid/operators/fused/fusion_group_op.cc (view file @ f0d193a2)

@@ -21,7 +21,7 @@ class FusionGroupOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
 
  void InferShape(framework::InferShapeContext* ctx) const override {
    const size_t num_ins = ctx->Inputs("Inputs").size();
    const size_t num_outs = ctx->Outputs("Outs").size();

@@ -58,6 +58,13 @@ class FusionGroupOp : public framework::OperatorWithKernel {
       ctx->ShareLoD("Inputs", /*->*/ "Outs", 0, j);
     }
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(framework::proto::VarType::FP32,
+                                   platform::CUDAPlace(0));
+  };
 };
 
 class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker {

@@ -69,6 +76,12 @@ class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Outs",
               "(std::vector<LoDTensor>) The outputs of fusion_group op.")
         .AsDuplicable();
+    AddAttr<std::vector<std::string>>(
+        "outs_data_type", "The data type of Outputs in fusion_group op.")
+        .SetDefault({});
+    AddAttr<std::vector<std::string>>(
+        "inputs_data_type", "The data type of Inputs in fusion_group op.")
+        .SetDefault({});
     AddAttr<int>("type", "Fusion type.").SetDefault(0);
     AddAttr<std::string>("func_name", "Name of the generated functions.")
         .SetDefault("");
paddle/fluid/operators/fused/fusion_group_op.h (view file @ f0d193a2)

@@ -22,6 +22,20 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+static void MutableMultiTypeData(
+    std::vector<paddle::framework::LoDTensor*>* var,
+    const std::vector<std::string>& data_type, const platform::Place& place) {
+  for (size_t i = 0; i < (*var).size(); i++) {
+    if (data_type[i] == "float") {
+      (*var)[i]->mutable_data<float>(place);
+    } else if (data_type[i] == "double") {
+      (*var)[i]->mutable_data<double>(place);
+    } else if (data_type[i] == "::paddle::platform::float16") {
+      (*var)[i]->mutable_data<paddle::platform::float16>(place);
+    }
+  }
+}
+
 template <typename DeviceContext, typename T>
 class FusionGroupKernel : public framework::OpKernel<T> {
  public:

@@ -29,14 +43,15 @@ class FusionGroupKernel : public framework::OpKernel<T> {
     auto ins = ctx.MultiInput<framework::LoDTensor>("Inputs");
     auto outs = ctx.MultiOutput<framework::LoDTensor>("Outs");
     int type = ctx.Attr<int>("type");
+    auto outs_type = ctx.Attr<std::vector<std::string>>("outs_data_type");
+    auto inputs_type = ctx.Attr<std::vector<std::string>>("inputs_data_type");
 
     size_t num_ins = ins.size();
     size_t num_outs = outs.size();
 
     auto place = ctx.GetPlace();
-    for (size_t i = 0; i < num_outs; ++i) {
-      outs[i]->mutable_data<T>(place);
-    }
+    MutableMultiTypeData(&outs, outs_type, place);
 
     std::string func_name = ctx.Attr<std::string>("func_name");
     platform::DeviceCode* dev_code =

@@ -47,13 +62,25 @@ class FusionGroupKernel : public framework::OpKernel<T> {
     size_t n = ins[0]->numel();
     std::vector<void*> args;
     args.push_back(&n);
-    std::vector<const T*> ptrs(num_ins + num_outs);
+    std::vector<const void*> ptrs(num_ins + num_outs);
     for (size_t i = 0; i < num_ins; ++i) {
-      ptrs[i] = ins[i]->data<T>();
+      if (inputs_type[i] == "::paddle::platform::float16") {
+        ptrs[i] = ins[i]->data<paddle::platform::float16>();
+      } else if (inputs_type[i] == "double") {
+        ptrs[i] = ins[i]->data<double>();
+      } else if (inputs_type[i] == "float") {
+        ptrs[i] = ins[i]->data<float>();
+      }
       args.push_back(&ptrs[i]);
     }
     for (size_t j = 0; j < num_outs; ++j) {
-      ptrs[num_ins + j] = outs[j]->data<T>();
+      if (outs_type[j] == "::paddle::platform::float16") {
+        ptrs[num_ins + j] = outs[j]->data<paddle::platform::float16>();
+      } else if (outs_type[j] == "double") {
+        ptrs[num_ins + j] = outs[j]->data<double>();
+      } else if (outs_type[j] == "float") {
+        ptrs[num_ins + j] = outs[j]->data<float>();
+      }
       args.push_back(&ptrs[num_ins + j]);
     }
     dev_code->Launch(n, &args);
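The launch plumbing above packs one void* per kernel parameter: the element count first, then the address of each type-erased data pointer, matching the generated signature (int N, T0* arg0, T1* arg1, ...). A standalone sketch of that packing (hypothetical host-side buffers; the real code hands args to platform::DeviceCode::Launch):

#include <cassert>
#include <vector>

int main() {
  // Pretend tensor buffers of different element types.
  float in0[4] = {0};
  double in1[4] = {0};
  float out0[4] = {0};

  size_t n = 4;
  std::vector<void*> args;
  std::vector<const void*> ptrs;  // typed pointers erased to const void*
  ptrs.push_back(in0);
  ptrs.push_back(in1);
  ptrs.push_back(out0);

  args.push_back(&n);  // first kernel argument: the element count
  for (size_t i = 0; i < ptrs.size(); ++i) {
    args.push_back(&ptrs[i]);  // then the address of each erased pointer
  }

  // A driver-API style launch would consume args.data(); here we only
  // check the layout: 1 count + one slot per input/output pointer.
  assert(args.size() == 1 + ptrs.size());
}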
paddle/fluid/operators/fused/fusion_group_op_test.cc (view file @ f0d193a2)

@@ -57,7 +57,8 @@ framework::OpDesc* CreateFusionGroupOp(
     const std::vector<std::string>& input_names,
     const std::vector<std::vector<int64_t>>& input_shapes,
-    const std::vector<std::string>& output_names, int type,
-    std::string func_name) {
+    const std::vector<std::string>& output_names, int type,
+    const std::vector<std::string>& inputs_data_type,
+    const std::vector<std::string>& outs_data_type, std::string func_name) {
   EXPECT_EQ(input_names.size(), input_shapes.size());
 
   for (size_t i = 0; i < input_names.size(); ++i) {

@@ -76,6 +77,8 @@ framework::OpDesc* CreateFusionGroupOp(
   op->SetType("fusion_group");
   op->SetInput("Inputs", input_names);
   op->SetOutput("Outs", output_names);
+  op->SetAttr("inputs_data_type", inputs_data_type);
+  op->SetAttr("outs_data_type", outs_data_type);
   op->SetAttr("type", type);
   op->SetAttr("func_name", func_name);
   op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(),

@@ -130,6 +133,8 @@ void CheckOutputs(framework::Scope* scope,
 void TestMain(const std::vector<std::string>& input_names,
               const std::vector<std::vector<int64_t>>& input_shapes,
               const std::vector<std::string>& output_names, int type,
+              const std::vector<std::string>& inputs_data_type,
+              const std::vector<std::string>& outs_data_type,
               std::string func_name, std::string cuda_kernel_str,
               CPUKernelFunc cpu_kernel_func) {
   // Compile the device code

@@ -139,8 +144,9 @@ void TestMain(const std::vector<std::string>& input_names,
   // Create a ProgramDesc that has a fusion_group_op.
   framework::ProgramDesc program;
-  framework::OpDesc* op_desc = CreateFusionGroupOp(
-      &program, input_names, input_shapes, output_names, type, func_name);
+  framework::OpDesc* op_desc =
+      CreateFusionGroupOp(&program, input_names, input_shapes, output_names,
+                          type, inputs_data_type, outs_data_type, func_name);
   auto fusion_group_op = framework::OpRegistry::CreateOp(*op_desc);
 
   framework::Scope scope;

@@ -210,8 +216,11 @@ void elementwise_cuda_kernel_0(size_t n, float *x, float* y, float* z) {
   }
 };
 
-  TestMain(input_names, input_shapes, output_names, 0,
-           "elementwise_cuda_kernel_0", kernel, elementwise_cpu_kernel_0);
+  std::vector<std::string> inputs_data_type(input_names.size(), "float");
+  std::vector<std::string> outs_data_type(output_names.size(), "float");
+  TestMain(input_names, input_shapes, output_names, 0, inputs_data_type,
+           outs_data_type, "elementwise_cuda_kernel_0", kernel,
+           elementwise_cpu_kernel_0);
 }
 
 }  // namespace operators
python/paddle/fluid/tests/unittests/ir/pass_test.py (view file @ f0d193a2)

@@ -142,8 +142,8 @@ class PassTest(unittest.TestCase):
             self.assertTrue(
                 np.allclose(
                     outs_opt[i], outs[i], atol=atol),
-                "Output < {} > has diff at {}".format(self.fetch_list[i].name,
-                                                      str(place)))
+                "Output < {} > has diff at {}, expected {} but got {}".format(
+                    self.fetch_list[i].name, str(place), outs_opt[i], outs[i]))
 
     def _check_fused_ops(self, program):
         '''
python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py (view file @ f0d193a2)

@@ -125,17 +125,15 @@ class FusionGroupPassTestFP16(FusionGroupPassTest):
                 fluid.data(
                     name="data2", shape=[128, 128], dtype=dtype))
 
-            # subgraph with only 1 op node
             tmp_0 = self.feed_vars[0] * self.feed_vars[1]
             tmp_1 = layers.mul(tmp_0, self.feed_vars[2])
-            tmp_2 = layers.cast(tmp_0, dtype="float16")
             tmp_3 = layers.cast(tmp_1, dtype="float16")
+            # subgraph with 2 op nodes
+            tmp_2 = layers.cast(tmp_0, dtype="float16")
             tmp_4 = layers.relu(tmp_2 + tmp_3)
             tmp_5 = layers.cast(tmp_4, dtype=dtype)
 
-        self.fetch_list = [tmp_5]
-        self.num_fused_ops = 1
+        self.fetch_list = [tmp_0, tmp_1, tmp_2, tmp_3, tmp_4, tmp_5]
+        self.num_fused_ops = 2
 
 
 class FusionGroupPassSumTest(FusionGroupPassTest):

@@ -147,9 +145,28 @@ class FusionGroupPassSumTest(FusionGroupPassTest):
             tmp_1 = layers.sum([tmp_0, self.feed_vars[2], self.feed_vars[3]])
             tmp_2 = layers.sum([tmp_1, self.feed_vars[4]])
 
-        self.fetch_list = [tmp_0, tmp_1]
+        self.fetch_list = [tmp_0, tmp_1, tmp_2]
         self.num_fused_ops = 1
 
 
+class FusionGroupPassCastTest(FusionGroupPassTest):
+    def build_program(self, dtype):
+        with fluid.program_guard(self.main_program, self.startup_program):
+            self.feed_vars = self._prepare_feed_vars([2, 2], dtype, 2)
+
+            tmp_0 = layers.elementwise_add(self.feed_vars[0],
+                                           self.feed_vars[1])
+            tmp_1 = layers.cast(tmp_0, dtype="double")
+            tmp_2 = layers.cast(tmp_1, dtype="float32")
+
+        self.fetch_list = [tmp_0, tmp_1, tmp_2]
+        self.num_fused_ops = 1
+
+    def setUp(self):
+        self.build_program("float64")
+        self.feeds = self._feed_random_data(self.feed_vars)
+        self.pass_names = "fusion_group_pass"
+        self.fused_op_type = "fusion_group"
+
+
 if __name__ == "__main__":
     unittest.main()