Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
f0d193a2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f0d193a2
编写于
3月 12, 2020
作者:
W
wangchaochaohu
提交者:
GitHub
3月 12, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Cast fusion for fusion group (#22876)
* add support for expression type convert and add cast Op support in fusion group
上级
29a7a52d
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
337 addition
and
135 deletion
+337
-135
paddle/fluid/framework/ir/fusion_group/code_generator.cc
paddle/fluid/framework/ir/fusion_group/code_generator.cc
+79
-36
paddle/fluid/framework/ir/fusion_group/code_generator.h
paddle/fluid/framework/ir/fusion_group/code_generator.h
+5
-3
paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
.../fluid/framework/ir/fusion_group/code_generator_helper.cc
+45
-6
paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
...e/fluid/framework/ir/fusion_group/code_generator_helper.h
+18
-4
paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc
.../fluid/framework/ir/fusion_group/code_generator_tester.cc
+32
-29
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
...d/framework/ir/fusion_group/elementwise_group_detector.cc
+47
-1
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h
...id/framework/ir/fusion_group/elementwise_group_detector.h
+6
-1
paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
+8
-0
paddle/fluid/framework/ir/fusion_group/operation.cc
paddle/fluid/framework/ir/fusion_group/operation.cc
+10
-1
paddle/fluid/framework/ir/fusion_group/subgraph.h
paddle/fluid/framework/ir/fusion_group/subgraph.h
+1
-34
paddle/fluid/operators/fused/fusion_group_op.cc
paddle/fluid/operators/fused/fusion_group_op.cc
+14
-1
paddle/fluid/operators/fused/fusion_group_op.h
paddle/fluid/operators/fused/fusion_group_op.h
+33
-6
paddle/fluid/operators/fused/fusion_group_op_test.cc
paddle/fluid/operators/fused/fusion_group_op_test.cc
+14
-5
python/paddle/fluid/tests/unittests/ir/pass_test.py
python/paddle/fluid/tests/unittests/ir/pass_test.py
+2
-2
python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
...dle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
+23
-6
未找到文件。
paddle/fluid/framework/ir/fusion_group/code_generator.cc
浏览文件 @
f0d193a2
...
...
@@ -24,6 +24,21 @@ namespace framework {
namespace
ir
{
namespace
fusion_group
{
std
::
string
ExtractDataType
(
const
std
::
vector
<
Node
*>
nodes
)
{
std
::
string
dtype_str
=
"float"
;
auto
data_type
=
nodes
.
back
()
->
Var
()
->
GetDataType
();
if
(
data_type
==
proto
::
VarType
::
FP32
)
{
dtype_str
=
"float"
;
}
else
if
(
data_type
==
proto
::
VarType
::
FP64
)
{
dtype_str
=
"double"
;
}
else
if
(
data_type
==
proto
::
VarType
::
FP16
)
{
dtype_str
=
"float16"
;
}
return
dtype_str
;
}
CodeGenerator
::
CodeGenerator
()
{
// Only support elementwise operations now.
code_templates_
.
resize
(
1
);
...
...
@@ -34,8 +49,7 @@ CodeGenerator::CodeGenerator() {
std
::
string
CodeGenerator
::
Generate
(
SubGraph
*
subgraph
)
{
std
::
vector
<
OperationExpression
>
expressions
=
ConvertToExpressions
(
subgraph
);
return
Generate
(
subgraph
->
GetFuncName
(),
subgraph
->
GetDataType
(),
expressions
);
return
Generate
(
subgraph
->
GetFuncName
(),
expressions
);
}
static
bool
HasInput
(
Node
*
n
,
std
::
string
name
)
{
...
...
@@ -95,8 +109,11 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
"Output(%s) of operation %s is not set."
,
name
,
op
->
Type
()));
output_ids
.
push_back
(
var_ids
[
op
->
Output
(
name
)[
0
]]);
}
expressions
.
push_back
(
OperationExpression
(
node
->
Name
(),
input_ids
,
output_ids
));
std
::
string
lhs_type
=
ExtractDataType
(
node
->
outputs
);
std
::
string
rhs_type
=
ExtractDataType
(
node
->
inputs
);
expressions
.
emplace_back
(
OperationExpression
(
node
->
Name
(),
input_ids
,
output_ids
,
rhs_type
,
lhs_type
));
}
}
return
expressions
;
...
...
@@ -105,25 +122,32 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
// In order to get the right result of expression, we need to calculate and
// store the expression as suffix Expressions using vector.
std
::
string
CodeGenerator
::
Generate
(
std
::
string
func_name
,
std
::
string
dtype
,
std
::
string
func_name
,
const
std
::
vector
<
OperationExpression
>&
expressions
)
{
// TODO(liuyiqun): Check whether all expressions are elementwise operations.
std
::
set
<
int
>
input_ids
=
DistilInputIds
(
expressions
);
std
::
set
<
int
>
output_ids
=
DistilOutputIds
(
expressions
);
std
::
unordered_map
<
int
,
std
::
string
>
dtypes
=
DistilDtypes
(
expressions
);
TemplateVariable
template_var
;
template_var
.
Add
(
"func_name"
,
func_name
);
template_var
.
Add
(
"parameters"
,
EmitParameters
(
input_ids
,
output_ids
,
dtype
));
template_var
.
Add
(
"parameters"
,
EmitParameters
(
input_ids
,
output_ids
,
dtype
s
));
template_var
.
Add
(
"compute_body"
,
EmitComputeBody
(
expressions
,
input_ids
,
output_ids
,
dtype
));
EmitComputeBody
(
expressions
,
input_ids
,
output_ids
,
dtype
s
));
std
::
string
predefined_cuda_functions
;
if
(
dtype
==
"float"
)
{
predefined_cuda_functions
=
predefined_cuda_functions_fp32
;
}
else
if
(
dtype
==
"double"
)
{
predefined_cuda_functions
=
predefined_cuda_functions_fp64
;
}
else
if
(
dtype
==
"float16"
)
{
predefined_cuda_functions
=
predefined_cuda_functions_fp16
;
std
::
set
<
std
::
string
>
all_dtype
;
for
(
const
auto
&
type
:
dtypes
)
{
all_dtype
.
insert
(
type
.
second
);
}
std
::
string
predefined_cuda_functions
=
""
;
if
(
all_dtype
.
find
(
"float"
)
!=
all_dtype
.
end
()
&&
all_dtype
.
find
(
"float16"
)
==
all_dtype
.
end
())
{
predefined_cuda_functions
+=
predefined_cuda_functions_fp32
;
}
if
(
all_dtype
.
find
(
"double"
)
!=
all_dtype
.
end
())
{
predefined_cuda_functions
+=
predefined_cuda_functions_fp64
;
}
if
(
all_dtype
.
find
(
"float16"
)
!=
all_dtype
.
end
())
{
predefined_cuda_functions
+=
predefined_cuda_functions_fp16
;
}
return
predefined_cuda_functions
+
code_templates_
[
0
].
Format
(
template_var
);
}
...
...
@@ -154,10 +178,40 @@ std::set<int> CodeGenerator::DistilOutputIds(
return
output_ids
;
}
std
::
unordered_map
<
int
,
std
::
string
>
CodeGenerator
::
DistilDtypes
(
const
std
::
vector
<
OperationExpression
>&
expressions
)
{
std
::
unordered_map
<
int
,
std
::
string
>
dtypes
;
for
(
const
auto
&
expression
:
expressions
)
{
for
(
auto
id
:
expression
.
GetInputIds
())
{
auto
dtype
=
expression
.
GetRHSType
();
if
(
dtypes
.
find
(
id
)
==
dtypes
.
end
())
{
dtypes
[
id
]
=
dtype
;
}
else
{
PADDLE_ENFORCE_EQ
(
dtypes
[
id
],
dtype
,
platform
::
errors
::
PreconditionNotMet
(
"In fusion group, Same Node id must have same date type"
));
}
}
for
(
auto
id
:
expression
.
GetOutputIds
())
{
auto
dtype
=
expression
.
GetLHSType
();
if
(
dtypes
.
find
(
id
)
==
dtypes
.
end
())
{
dtypes
[
id
]
=
dtype
;
}
else
{
PADDLE_ENFORCE_EQ
(
dtypes
[
id
],
dtype
,
platform
::
errors
::
PreconditionNotMet
(
"In fusion group, Same Node id must have same date type"
));
}
}
}
return
dtypes
;
}
// we get the parameter list code for the expression information
std
::
string
CodeGenerator
::
EmitParameters
(
const
std
::
set
<
int
>&
input_ids
,
const
std
::
set
<
int
>&
output_ids
,
std
::
string
dtype
)
{
std
::
string
CodeGenerator
::
EmitParameters
(
const
std
::
set
<
int
>&
input_ids
,
const
std
::
set
<
int
>&
output_ids
,
std
::
unordered_map
<
int
,
std
::
string
>
dtypes
)
{
std
::
stringstream
ret
;
ret
<<
"int N, "
;
...
...
@@ -165,13 +219,13 @@ std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
// from the input list.
for
(
auto
id
:
input_ids
)
{
if
(
output_ids
.
find
(
id
)
==
output_ids
.
end
())
{
ret
<<
dtype
<<
"* "
<<
ArgName
(
id
)
<<
", "
;
ret
<<
dtype
s
[
id
]
<<
"* "
<<
ArgName
(
id
)
<<
", "
;
}
}
size_t
index
=
0
;
for
(
auto
id
:
output_ids
)
{
ret
<<
dtype
<<
"* "
<<
ArgName
(
id
);
ret
<<
dtype
s
[
id
]
<<
"* "
<<
ArgName
(
id
);
if
(
index
!=
output_ids
.
size
()
-
1
)
{
ret
<<
", "
;
}
...
...
@@ -184,13 +238,12 @@ std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
std
::
string
CodeGenerator
::
EmitComputeBody
(
const
std
::
vector
<
OperationExpression
>&
expressions
,
const
std
::
set
<
int
>&
input_ids
,
const
std
::
set
<
int
>&
output_ids
,
std
::
string
dtype
)
{
std
::
unordered_map
<
int
,
std
::
string
>
dtypes
)
{
std
::
ostringstream
compute
;
std
::
unordered_set
<
int
>
used
;
std
::
string
compute_dtype
=
(
dtype
==
"float16"
)
?
"float"
:
dtype
;
for
(
size_t
i
=
0
;
i
<
expressions
.
size
();
i
++
)
{
VLOG
(
3
)
<<
DebugString
(
expressions
[
i
]);
compute
<<
expressions
[
i
].
GetExpression
(
compute_dtype
,
&
used
);
compute
<<
expressions
[
i
].
GetExpression
(
&
used
);
}
// Load input to temporal variables.
...
...
@@ -198,23 +251,13 @@ std::string CodeGenerator::EmitComputeBody(
for
(
auto
id
:
input_ids
)
{
if
(
output_ids
.
find
(
id
)
==
output_ids
.
end
()
&&
used
.
find
(
id
)
!=
used
.
end
())
{
if
(
dtype
==
"float16"
)
{
load
<<
"float "
<<
TmpName
(
id
)
<<
" = __half2float("
<<
ArgName
(
id
)
<<
"[idx]);"
;
}
else
{
load
<<
dtype
<<
" "
<<
TmpName
(
id
)
<<
" = "
<<
ArgName
(
id
)
<<
"[idx];"
;
}
load
<<
dtypes
[
id
]
<<
" "
<<
TmpName
(
id
)
<<
" = "
<<
VarName
(
id
)
<<
";"
;
}
}
// Store temporal variables to memory.
std
::
ostringstream
store
;
for
(
auto
id
:
output_ids
)
{
if
(
dtype
==
"float16"
)
{
store
<<
ArgName
(
id
)
<<
"[idx] = __float2half("
<<
TmpName
(
id
)
<<
");"
;
}
else
{
store
<<
ArgName
(
id
)
<<
"[idx] = "
<<
TmpName
(
id
)
<<
";"
;
}
store
<<
VarName
(
id
)
<<
" = "
<<
TmpName
(
id
)
<<
";"
;
}
return
load
.
str
()
+
compute
.
str
()
+
store
.
str
();
...
...
paddle/fluid/framework/ir/fusion_group/code_generator.h
浏览文件 @
f0d193a2
...
...
@@ -30,7 +30,7 @@ class CodeGenerator {
public:
CodeGenerator
();
std
::
string
Generate
(
std
::
string
func_name
,
std
::
string
dtype
,
std
::
string
Generate
(
std
::
string
func_name
,
const
std
::
vector
<
OperationExpression
>&
expressions
);
std
::
string
Generate
(
SubGraph
*
subgraph
);
...
...
@@ -42,16 +42,18 @@ class CodeGenerator {
const
std
::
vector
<
OperationExpression
>&
expressions
);
std
::
set
<
int
>
DistilOutputIds
(
const
std
::
vector
<
OperationExpression
>&
expressions
);
std
::
unordered_map
<
int
,
std
::
string
>
DistilDtypes
(
const
std
::
vector
<
OperationExpression
>&
expressions
);
// we get the parameter list code for the expression information
std
::
string
EmitParameters
(
const
std
::
set
<
int
>&
input_ids
,
const
std
::
set
<
int
>&
output_ids
,
std
::
string
dtype
);
std
::
unordered_map
<
int
,
std
::
string
>
dtypes
);
std
::
string
EmitComputeBody
(
const
std
::
vector
<
OperationExpression
>&
expressions
,
const
std
::
set
<
int
>&
input_ids
,
const
std
::
set
<
int
>&
output_ids
,
std
::
string
dtype
);
std
::
unordered_map
<
int
,
std
::
string
>
dtypes
);
// Encode all var nodes in the subgraph with an unique number.
std
::
unordered_map
<
std
::
string
,
int
>
EncodeVarNodes
(
SubGraph
*
subgraph
);
...
...
paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
浏览文件 @
f0d193a2
...
...
@@ -50,10 +50,26 @@ static std::string ExpandMultivariateTemplate(const std::string rhs,
return
sum_rhs
;
}
// In order to avoid multiple __half2float function calls, we do this
// optimization
static
std
::
string
OptimzeFP16RHS
(
std
::
unordered_set
<
int
>*
used
,
const
int
index
,
const
std
::
vector
<
int
>&
input_ids
)
{
std
::
stringstream
ret
;
if
(
used
->
find
(
input_ids
[
index
])
==
used
->
end
())
{
ret
<<
"float half2fp32_"
+
TmpName
(
input_ids
[
index
])
+
" = __half2float("
+
TmpName
(
input_ids
[
index
])
+
");"
;
}
return
ret
.
str
();
}
std
::
string
OperationExpression
::
GetRHS
(
std
::
unordered_set
<
int
>*
used
,
std
::
string
*
half2fp32_statement
,
size_t
exprs_index
)
const
{
auto
rhs
=
OperationMap
::
Instance
().
Get
(
op_type_
).
exprs
[
exprs_index
];
auto
num_operands
=
OperationMap
::
Instance
().
Get
(
op_type_
).
num_operands
;
if
(
num_operands
==
-
1
)
{
size_t
input_size
=
input_ids_
.
size
();
rhs
=
ExpandMultivariateTemplate
(
rhs
,
input_size
);
...
...
@@ -78,7 +94,16 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
platform
::
errors
::
InvalidArgument
(
"Expected %d-th input id > 0 for operation < %s >. Received %d."
,
index
,
op_type_
,
input_ids_
[
index
]));
rhs
.
replace
(
pos
,
length
+
3
,
TmpName
(
input_ids_
[
index
]));
// TODO(wangchaochaohu): Here fp16 convert to float to do comupte, we need
// to add general fp16 compute later.
std
::
string
var_name
;
if
(
rhs_type_
==
"float16"
)
{
half2fp32_statement
->
append
(
OptimzeFP16RHS
(
used
,
index
,
input_ids_
));
var_name
=
"half2fp32_"
+
TmpName
(
input_ids_
[
index
]);
}
else
{
var_name
=
TmpName
(
input_ids_
[
index
]);
}
rhs
.
replace
(
pos
,
length
+
3
,
var_name
);
used
->
insert
(
input_ids_
[
index
]);
}
}
...
...
@@ -87,7 +112,7 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
std
::
string
OperationExpression
::
GetLHS
(
size_t
i
)
const
{
std
::
stringstream
ret
;
ret
<<
TmpName
(
output_ids_
[
i
]);
ret
<<
lhs_type_
<<
" "
<<
TmpName
(
output_ids_
[
i
]);
return
ret
.
str
();
}
...
...
@@ -98,15 +123,29 @@ bool OperationExpression::IsSupport() const {
// we Traverse the graph and get the group , all input id and output id is
// unique for the node which belong the group
std
::
string
OperationExpression
::
GetExpression
(
std
::
string
dtype
,
std
::
unordered_set
<
int
>*
used
)
const
{
std
::
unordered_set
<
int
>*
used
)
const
{
std
::
string
half2fp32_statement
;
std
::
stringstream
ret
;
if
(
IsSupport
())
{
for
(
size_t
i
=
0
;
i
<
output_ids_
.
size
();
++
i
)
{
ret
<<
dtype
<<
" "
<<
GetLHS
(
i
)
<<
" = "
<<
GetRHS
(
used
,
i
)
<<
";"
;
std
::
string
cast_str
=
""
;
if
((
lhs_type_
==
rhs_type_
&&
rhs_type_
!=
"float16"
)
||
(
lhs_type_
!=
rhs_type_
&&
rhs_type_
==
"float16"
))
{
ret
<<
GetLHS
(
i
)
<<
" = "
<<
GetRHS
(
used
,
&
half2fp32_statement
,
i
)
<<
";"
;
}
else
{
if
((
lhs_type_
==
rhs_type_
&&
rhs_type_
==
"float16"
)
||
lhs_type_
==
"float16"
)
{
cast_str
=
"__float2half"
;
}
else
{
cast_str
=
"static_cast<"
+
lhs_type_
+
">"
;
}
ret
<<
GetLHS
(
i
)
<<
" = "
<<
cast_str
<<
"("
<<
GetRHS
(
used
,
&
half2fp32_statement
,
i
)
<<
");"
;
}
return
ret
.
str
();
}
}
return
half2fp32_statement
+
ret
.
str
();
}
}
// namespace fusion_group
...
...
paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
浏览文件 @
f0d193a2
...
...
@@ -30,29 +30,41 @@ namespace fusion_group {
static
inline
std
::
string
ArgName
(
int
index
)
{
return
"arg"
+
std
::
to_string
(
index
);
}
static
inline
std
::
string
TmpName
(
int
index
)
{
return
"tmp"
+
std
::
to_string
(
index
);
}
static
inline
std
::
string
VarName
(
int
index
)
{
return
"arg"
+
std
::
to_string
(
index
)
+
"[idx]"
;
}
class
OperationExpression
{
public:
explicit
OperationExpression
(
std
::
string
op_type
,
std
::
vector
<
int
>
input_ids
,
std
::
vector
<
int
>
output_ids
)
:
op_type_
(
op_type
),
input_ids_
(
input_ids
),
output_ids_
(
output_ids
)
{}
std
::
vector
<
int
>
output_ids
,
std
::
string
rhs_type
,
std
::
string
lhs_type
)
:
op_type_
(
op_type
),
input_ids_
(
input_ids
),
output_ids_
(
output_ids
),
rhs_type_
(
rhs_type
),
lhs_type_
(
lhs_type
)
{}
std
::
string
GetOpType
()
const
{
return
op_type_
;
}
std
::
vector
<
int
>
GetInputIds
()
const
{
return
input_ids_
;
}
std
::
vector
<
int
>
GetOutputIds
()
const
{
return
output_ids_
;
}
std
::
string
GetRHSType
()
const
{
return
rhs_type_
;
}
std
::
string
GetLHSType
()
const
{
return
lhs_type_
;
}
// Check whether this operation type is supported in OperationMap.
bool
IsSupport
()
const
;
std
::
string
GetExpression
(
std
::
string
dtype
,
std
::
unordered_set
<
int
>*
used
)
const
;
std
::
string
GetExpression
(
std
::
unordered_set
<
int
>*
used
)
const
;
private:
// TODO(wangchao): make offset more flexible we add stride and basic offset
std
::
string
GetRHS
(
std
::
unordered_set
<
int
>*
used
,
std
::
string
*
half2fp32_statement
,
size_t
exprs_index
=
0
)
const
;
std
::
string
GetLHS
(
size_t
i
=
0
)
const
;
...
...
@@ -60,6 +72,8 @@ class OperationExpression {
std
::
string
op_type_
;
std
::
vector
<
int
>
input_ids_
;
std
::
vector
<
int
>
output_ids_
;
std
::
string
rhs_type_
;
std
::
string
lhs_type_
;
};
class
TemplateVariable
{
...
...
paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc
浏览文件 @
f0d193a2
...
...
@@ -288,7 +288,7 @@ void TestMain(std::string func_name,
std
::
string
dtype
)
{
fusion_group
::
OperationMap
::
Init
();
fusion_group
::
CodeGenerator
code_generator
;
std
::
string
code_str
=
code_generator
.
Generate
(
func_name
,
dtype
,
expressions
);
std
::
string
code_str
=
code_generator
.
Generate
(
func_name
,
expressions
);
VLOG
(
3
)
<<
code_str
;
LOG
(
INFO
)
<<
"dtype: "
<<
dtype
;
...
...
@@ -297,7 +297,7 @@ void TestMain(std::string func_name,
}
void
TestMain
(
fusion_group
::
SubGraph
*
subgraph
,
std
::
vector
<
int
>
input_ids
,
std
::
vector
<
int
>
output_ids
)
{
std
::
vector
<
int
>
output_ids
,
std
::
string
dtype
)
{
fusion_group
::
OperationMap
::
Init
();
fusion_group
::
CodeGenerator
code_generator
;
std
::
string
code_str
=
code_generator
.
Generate
(
subgraph
);
...
...
@@ -307,26 +307,28 @@ void TestMain(fusion_group::SubGraph* subgraph, std::vector<int> input_ids,
std
::
vector
<
fusion_group
::
OperationExpression
>
expressions
=
code_generator
.
ConvertToExpressions
(
subgraph
);
LOG
(
INFO
)
<<
"dtype: "
<<
subgraph
->
GetDataType
();
TestElementwiseMain
(
subgraph
->
GetFuncName
(),
code_str
,
expressions
,
input_ids
,
output_ids
,
subgraph
->
GetDataType
()
);
output_ids
,
dtype
);
}
TEST
(
code_generator
,
elementwise
)
{
for
(
std
::
string
dtype
:
{
"float"
,
"float16"
})
{
// t2 = t0 * t1
// t4 = t2 + t3
// t6 = t4 - t5
// t7 = relu(t6)
// t8 = sigmoid(t7)
fusion_group
::
OperationExpression
exp1
(
"elementwise_mul"
,
{
0
,
1
},
{
2
});
fusion_group
::
OperationExpression
exp2
(
"elementwise_add"
,
{
2
,
3
},
{
4
});
fusion_group
::
OperationExpression
exp3
(
"elementwise_sub"
,
{
4
,
5
},
{
6
});
fusion_group
::
OperationExpression
exp4
(
"relu"
,
{
6
},
{
7
});
fusion_group
::
OperationExpression
exp5
(
"sigmoid"
,
{
7
},
{
8
});
fusion_group
::
OperationExpression
exp1
(
"elementwise_mul"
,
{
0
,
1
},
{
2
},
dtype
,
dtype
);
fusion_group
::
OperationExpression
exp2
(
"elementwise_add"
,
{
2
,
3
},
{
4
},
dtype
,
dtype
);
fusion_group
::
OperationExpression
exp3
(
"elementwise_sub"
,
{
4
,
5
},
{
6
},
dtype
,
dtype
);
fusion_group
::
OperationExpression
exp4
(
"relu"
,
{
6
},
{
7
},
dtype
,
dtype
);
fusion_group
::
OperationExpression
exp5
(
"sigmoid"
,
{
7
},
{
8
},
dtype
,
dtype
);
std
::
vector
<
fusion_group
::
OperationExpression
>
expressions
=
{
exp1
,
exp2
,
exp3
,
exp4
,
exp5
};
for
(
std
::
string
dtype
:
{
"float"
,
"float16"
})
{
// Expressions:
// Op(elementwise_mul), inputs:{0,1}, outputs:{2}
// Op(elementwise_add), inputs:{2,3}, outputs:{4}
...
...
@@ -340,17 +342,18 @@ TEST(code_generator, elementwise) {
}
TEST
(
code_generator
,
elementwise_grad
)
{
for
(
std
::
string
dtype
:
{
"float"
,
"float16"
})
{
// The var order: t0, t1, t2, t3, t0', t1', t2', t3'
// t2 = t0 * t1
// t3 = relu(t2)
// t2' = relu_grad(t2, t3, t3')
// t0', t1' = elementwise_mul_grad(t0, t1, t2, t2')
fusion_group
::
OperationExpression
exp1
(
"relu_grad"
,
{
-
1
,
3
,
7
},
{
6
});
fusion_group
::
OperationExpression
exp1
(
"relu_grad"
,
{
-
1
,
3
,
7
},
{
6
},
dtype
,
dtype
);
fusion_group
::
OperationExpression
exp2
(
"elementwise_mul_grad"
,
{
0
,
1
,
2
,
6
},
{
4
,
5
}
);
{
4
,
5
},
dtype
,
dtype
);
std
::
vector
<
fusion_group
::
OperationExpression
>
expressions
=
{
exp1
,
exp2
};
for
(
std
::
string
dtype
:
{
"float"
,
"float16"
})
{
// Expressions:
// Op(relu_grad), inputs:{2,3,7}, outputs:{6}
// Op(elementwise_mul_grad), inputs:{0,1,2,6}, outputs:{4,5}
...
...
@@ -474,7 +477,7 @@ TEST(code_generator, subgraph) {
// Op(elementwise_add), inputs:{7,6}, outputs:{8}
std
::
vector
<
int
>
input_ids
=
{
0
,
1
,
2
,
3
};
std
::
vector
<
int
>
output_ids
=
{
4
,
5
,
6
,
7
,
8
};
TestMain
(
&
subgraph
,
input_ids
,
output_ids
);
TestMain
(
&
subgraph
,
input_ids
,
output_ids
,
dtype
);
}
}
...
...
@@ -493,7 +496,7 @@ TEST(code_generator, subgraph_grad) {
// Op(tanh_grad), inputs:{9,4,13}, outputs:{14}
std
::
vector
<
int
>
input_ids
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
};
std
::
vector
<
int
>
output_ids
=
{
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
};
TestMain
(
&
subgraph
,
input_ids
,
output_ids
);
TestMain
(
&
subgraph
,
input_ids
,
output_ids
,
dtype
);
}
}
#endif
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
浏览文件 @
f0d193a2
...
...
@@ -60,6 +60,50 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
return
l
.
size
()
!=
0U
&&
r
.
size
()
!=
0U
&&
l
==
r
;
}
bool
GroupDetector
::
IsFusionGroupOp
(
const
Node
*
n
)
{
if
(
!
(
n
&&
n
->
IsOp
()
&&
n
->
Op
()))
return
false
;
bool
is_first
=
true
;
proto
::
VarType
::
Type
i_data_type
=
proto
::
VarType
::
FP32
;
proto
::
VarType
::
Type
o_data_type
=
proto
::
VarType
::
FP32
;
for
(
auto
*
i_node
:
n
->
inputs
)
{
if
(
!
i_node
->
Var
())
return
false
;
if
(
i_node
->
Var
()
->
GetType
()
!=
proto
::
VarType
::
LOD_TENSOR
)
{
return
false
;
}
if
(
is_first
)
{
i_data_type
=
i_node
->
Var
()
->
GetDataType
();
is_first
=
false
;
}
else
{
if
(
i_data_type
!=
i_node
->
Var
()
->
GetDataType
())
return
false
;
}
}
is_first
=
true
;
for
(
auto
*
o_node
:
n
->
outputs
)
{
if
(
!
o_node
->
Var
())
return
false
;
if
(
o_node
->
Var
()
->
GetType
()
!=
proto
::
VarType
::
LOD_TENSOR
)
{
return
false
;
}
if
(
is_first
)
{
o_data_type
=
o_node
->
Var
()
->
GetDataType
();
is_first
=
false
;
}
else
{
if
(
o_data_type
!=
o_node
->
Var
()
->
GetDataType
())
return
false
;
}
}
if
(
!
(
i_data_type
==
proto
::
VarType
::
FP32
||
i_data_type
==
proto
::
VarType
::
FP64
||
i_data_type
==
proto
::
VarType
::
FP16
)
||
!
(
o_data_type
==
proto
::
VarType
::
FP32
||
o_data_type
==
proto
::
VarType
::
FP64
||
o_data_type
==
proto
::
VarType
::
FP16
))
return
false
;
return
true
;
}
bool
ElementwiseGroupDetector
::
IsElementwiseOp
(
const
Node
*
n
)
{
if
(
IsSpecifiedOp
(
GetElementwiseOpTypes
(),
n
))
{
std
::
vector
<
int64_t
>
shape_0
;
...
...
@@ -85,7 +129,9 @@ bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
std
::
vector
<
std
::
vector
<
Node
*>>
ElementwiseGroupDetector
::
operator
()(
Graph
*
graph
)
{
auto
teller
=
[
&
](
const
Node
*
n
)
->
bool
{
return
IsElementwiseOp
(
n
);
};
auto
teller
=
[
&
](
const
Node
*
n
)
->
bool
{
return
IsFusionGroupOp
(
n
)
&&
IsElementwiseOp
(
n
);
};
return
SubgraphDetector
(
graph
,
teller
)();
}
...
...
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h
浏览文件 @
f0d193a2
...
...
@@ -23,7 +23,12 @@ namespace framework {
namespace
ir
{
namespace
fusion_group
{
class
ElementwiseGroupDetector
{
class
GroupDetector
{
protected:
bool
IsFusionGroupOp
(
const
Node
*
n
);
};
class
ElementwiseGroupDetector
:
GroupDetector
{
public:
std
::
vector
<
std
::
vector
<
Node
*>>
operator
()(
Graph
*
graph
);
...
...
paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
浏览文件 @
f0d193a2
...
...
@@ -110,18 +110,25 @@ void FusionGroupPass::InsertFusionGroupOp(
op_desc
.
SetType
(
"fusion_group"
);
std
::
vector
<
std
::
string
>
input_names
;
std
::
vector
<
std
::
string
>
inputs_data_types
;
for
(
auto
*
n
:
input_vars_of_subgraph
)
{
input_names
.
push_back
(
n
->
Name
());
inputs_data_types
.
push_back
(
DataTypeToString
(
n
->
Var
()
->
GetDataType
()));
external_nodes
.
insert
(
n
);
}
op_desc
.
SetInput
(
"Inputs"
,
input_names
);
std
::
vector
<
std
::
string
>
output_names
;
std
::
vector
<
std
::
string
>
outs_data_types
;
for
(
auto
*
n
:
output_vars_of_subgraph
)
{
output_names
.
push_back
(
n
->
Name
());
outs_data_types
.
push_back
(
DataTypeToString
(
n
->
Var
()
->
GetDataType
()));
external_nodes
.
insert
(
n
);
}
op_desc
.
SetOutput
(
"Outs"
,
output_names
);
op_desc
.
SetAttr
(
"inputs_data_type"
,
inputs_data_types
);
op_desc
.
SetAttr
(
"outs_data_type"
,
outs_data_types
);
op_desc
.
SetAttr
(
"type"
,
subgraph
->
GetType
());
op_desc
.
SetAttr
(
"func_name"
,
subgraph
->
GetFuncName
());
op_desc
.
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
...
...
@@ -131,6 +138,7 @@ void FusionGroupPass::InsertFusionGroupOp(
for
(
auto
*
in
:
input_vars_of_subgraph
)
{
IR_NODE_LINK_TO
(
in
,
fusion_group_node
);
}
for
(
auto
*
out
:
output_vars_of_subgraph
)
{
IR_NODE_LINK_TO
(
fusion_group_node
,
out
);
}
...
...
paddle/fluid/framework/ir/fusion_group/operation.cc
浏览文件 @
f0d193a2
...
...
@@ -102,6 +102,13 @@ void OperationMap::InsertUnaryElementwiseOperations() {
// dx = dout * (1 - out * out)
insert_handler
(
"tanh"
,
"2.0 / (1.0 + real_exp(-2.0 * ${0})) - 1.0"
,
{
"${2} * (1.0 - ${1} * ${1})"
});
// cast
// out = static_cast<T>(d)
// dx = static_cast<T>(d_out)
// TODO(wangchaochaohu): This is not the compelete definition of
// cast Op, We need refine it later.
insert_handler
(
"cast"
,
"${0}"
,
{
"${0}"
});
}
void
OperationMap
::
InsertBinaryElementwiseOperations
()
{
...
...
@@ -158,10 +165,12 @@ void OperationMap::InsertMultivariateElementwiseOperations() {
std
::
vector
<
std
::
string
>
grad_exprs
)
{
int
type
=
0
;
int
num_oprands
=
-
1
;
// here ... represent the number of input is changed
Insert
(
type
,
num_oprands
,
op_type
,
expr
,
grad_exprs
,
{
"X"
},
{
"Out"
});
};
// here [] represent the number of input is positive(>=0).
// if input list size of Sum Op is 3, It will expand as
// ${0} + ${1} + ${2}
insert_handler
(
"sum"
,
"${0}[ + ${?}]"
,
{});
}
...
...
paddle/fluid/framework/ir/fusion_group/subgraph.h
浏览文件 @
f0d193a2
...
...
@@ -49,7 +49,6 @@ class SubGraph {
}
}
}
ExtractDataType
();
}
bool
IsValid
(
int
min_subgraph_size
)
{
...
...
@@ -61,11 +60,10 @@ class SubGraph {
return
false
;
}
return
ExtractDataType
()
;
return
true
;
}
int
GetType
()
const
{
return
type_
;
}
std
::
string
GetDataType
()
const
{
return
data_type_
;
}
void
SetFuncName
(
std
::
string
func_name
)
{
func_name_
=
func_name
;
}
std
::
string
GetFuncName
()
const
{
return
func_name_
;
}
...
...
@@ -162,37 +160,6 @@ class SubGraph {
}
private:
bool
ExtractDataType
()
{
bool
is_first
=
true
;
proto
::
VarType
::
Type
data_type
=
proto
::
VarType
::
FP32
;
for
(
auto
*
n
:
nodes_set_
)
{
if
(
n
&&
n
->
IsVar
()
&&
n
->
Var
())
{
if
(
n
->
Var
()
->
GetType
()
!=
proto
::
VarType
::
LOD_TENSOR
)
{
// All var node in a subgraph should hold a LoDTensor.
return
false
;
}
if
(
is_first
)
{
data_type
=
n
->
Var
()
->
GetDataType
();
is_first
=
false
;
}
else
if
(
n
->
Var
()
->
GetDataType
()
!=
data_type
)
{
// DataType of VarDesc in a subgraph is not the same.
return
false
;
}
}
}
if
(
data_type
==
proto
::
VarType
::
FP32
)
{
data_type_
=
"float"
;
}
else
if
(
data_type
==
proto
::
VarType
::
FP64
)
{
data_type_
=
"double"
;
}
else
if
(
data_type
==
proto
::
VarType
::
FP16
)
{
data_type_
=
"float16"
;
}
else
{
VLOG
(
2
)
<<
"Only support fp32, fp64 and fp16 in fusion_group."
;
return
false
;
}
return
true
;
}
void
TopologicalSort
()
{
if
(
!
is_sorted_
)
{
std
::
unordered_map
<
Node
*
,
std
::
vector
<
Node
*>>
inputs_map
;
...
...
paddle/fluid/operators/fused/fusion_group_op.cc
浏览文件 @
f0d193a2
...
...
@@ -21,7 +21,7 @@ class FusionGroupOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
const
size_t
num_ins
=
ctx
->
Inputs
(
"Inputs"
).
size
();
const
size_t
num_outs
=
ctx
->
Outputs
(
"Outs"
).
size
();
...
...
@@ -58,6 +58,13 @@ class FusionGroupOp : public framework::OperatorWithKernel {
ctx
->
ShareLoD
(
"Inputs"
,
/*->*/
"Outs"
,
0
,
j
);
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
proto
::
VarType
::
FP32
,
platform
::
CUDAPlace
(
0
));
};
};
class
FusionGroupOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
@@ -69,6 +76,12 @@ class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput
(
"Outs"
,
"(std::vector<LoDTensor>) The outputs of fusion_group op."
)
.
AsDuplicable
();
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"outs_data_type"
,
"The data type of Outputs in fusion_group op."
)
.
SetDefault
({});
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"inputs_data_type"
,
"The data type of Inputs in fusion_group op."
)
.
SetDefault
({});
AddAttr
<
int
>
(
"type"
,
"Fusion type."
).
SetDefault
(
0
);
AddAttr
<
std
::
string
>
(
"func_name"
,
"Name of the generated functions."
)
.
SetDefault
(
""
);
...
...
paddle/fluid/operators/fused/fusion_group_op.h
浏览文件 @
f0d193a2
...
...
@@ -22,6 +22,20 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
static
void
MutableMultiTypeData
(
std
::
vector
<
paddle
::
framework
::
LoDTensor
*>*
var
,
const
std
::
vector
<
std
::
string
>&
data_type
,
const
platform
::
Place
&
place
)
{
for
(
size_t
i
=
0
;
i
<
(
*
var
).
size
();
i
++
)
{
if
(
data_type
[
i
]
==
"float"
)
{
(
*
var
)[
i
]
->
mutable_data
<
float
>
(
place
);
}
else
if
(
data_type
[
i
]
==
"double"
)
{
(
*
var
)[
i
]
->
mutable_data
<
double
>
(
place
);
}
else
if
(
data_type
[
i
]
==
"::paddle::platform::float16"
)
{
(
*
var
)[
i
]
->
mutable_data
<
paddle
::
platform
::
float16
>
(
place
);
}
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
FusionGroupKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -29,14 +43,15 @@ class FusionGroupKernel : public framework::OpKernel<T> {
auto
ins
=
ctx
.
MultiInput
<
framework
::
LoDTensor
>
(
"Inputs"
);
auto
outs
=
ctx
.
MultiOutput
<
framework
::
LoDTensor
>
(
"Outs"
);
int
type
=
ctx
.
Attr
<
int
>
(
"type"
);
auto
outs_type
=
ctx
.
Attr
<
std
::
vector
<
std
::
string
>>
(
"outs_data_type"
);
auto
inputs_type
=
ctx
.
Attr
<
std
::
vector
<
std
::
string
>>
(
"inputs_data_type"
);
size_t
num_ins
=
ins
.
size
();
size_t
num_outs
=
outs
.
size
();
auto
place
=
ctx
.
GetPlace
();
for
(
size_t
i
=
0
;
i
<
num_outs
;
++
i
)
{
outs
[
i
]
->
mutable_data
<
T
>
(
place
);
}
MutableMultiTypeData
(
&
outs
,
outs_type
,
place
);
std
::
string
func_name
=
ctx
.
Attr
<
std
::
string
>
(
"func_name"
);
platform
::
DeviceCode
*
dev_code
=
...
...
@@ -47,13 +62,25 @@ class FusionGroupKernel : public framework::OpKernel<T> {
size_t
n
=
ins
[
0
]
->
numel
();
std
::
vector
<
void
*>
args
;
args
.
push_back
(
&
n
);
std
::
vector
<
const
T
*>
ptrs
(
num_ins
+
num_outs
);
std
::
vector
<
const
void
*>
ptrs
(
num_ins
+
num_outs
);
for
(
size_t
i
=
0
;
i
<
num_ins
;
++
i
)
{
ptrs
[
i
]
=
ins
[
i
]
->
data
<
T
>
();
if
(
inputs_type
[
i
]
==
"::paddle::platform::float16"
)
{
ptrs
[
i
]
=
ins
[
i
]
->
data
<
paddle
::
platform
::
float16
>
();
}
else
if
(
inputs_type
[
i
]
==
"double"
)
{
ptrs
[
i
]
=
ins
[
i
]
->
data
<
double
>
();
}
else
if
(
inputs_type
[
i
]
==
"float"
)
{
ptrs
[
i
]
=
ins
[
i
]
->
data
<
float
>
();
}
args
.
push_back
(
&
ptrs
[
i
]);
}
for
(
size_t
j
=
0
;
j
<
num_outs
;
++
j
)
{
ptrs
[
num_ins
+
j
]
=
outs
[
j
]
->
data
<
T
>
();
if
(
outs_type
[
j
]
==
"::paddle::platform::float16"
)
{
ptrs
[
num_ins
+
j
]
=
outs
[
j
]
->
data
<
paddle
::
platform
::
float16
>
();
}
else
if
(
outs_type
[
j
]
==
"double"
)
{
ptrs
[
num_ins
+
j
]
=
outs
[
j
]
->
data
<
double
>
();
}
else
if
(
outs_type
[
j
]
==
"float"
)
{
ptrs
[
num_ins
+
j
]
=
outs
[
j
]
->
data
<
float
>
();
}
args
.
push_back
(
&
ptrs
[
num_ins
+
j
]);
}
dev_code
->
Launch
(
n
,
&
args
);
...
...
paddle/fluid/operators/fused/fusion_group_op_test.cc
浏览文件 @
f0d193a2
...
...
@@ -57,7 +57,8 @@ framework::OpDesc* CreateFusionGroupOp(
const
std
::
vector
<
std
::
string
>&
input_names
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
input_shapes
,
const
std
::
vector
<
std
::
string
>&
output_names
,
int
type
,
std
::
string
func_name
)
{
const
std
::
vector
<
std
::
string
>&
inputs_data_type
,
const
std
::
vector
<
std
::
string
>&
outs_data_type
,
std
::
string
func_name
)
{
EXPECT_EQ
(
input_names
.
size
(),
input_shapes
.
size
());
for
(
size_t
i
=
0
;
i
<
input_names
.
size
();
++
i
)
{
...
...
@@ -76,6 +77,8 @@ framework::OpDesc* CreateFusionGroupOp(
op
->
SetType
(
"fusion_group"
);
op
->
SetInput
(
"Inputs"
,
input_names
);
op
->
SetOutput
(
"Outs"
,
output_names
);
op
->
SetAttr
(
"inputs_data_type"
,
inputs_data_type
);
op
->
SetAttr
(
"outs_data_type"
,
outs_data_type
);
op
->
SetAttr
(
"type"
,
type
);
op
->
SetAttr
(
"func_name"
,
func_name
);
op
->
SetAttr
(
framework
::
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
...
...
@@ -130,6 +133,8 @@ void CheckOutputs(framework::Scope* scope,
void
TestMain
(
const
std
::
vector
<
std
::
string
>&
input_names
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
input_shapes
,
const
std
::
vector
<
std
::
string
>&
output_names
,
int
type
,
const
std
::
vector
<
std
::
string
>&
inputs_data_type
,
const
std
::
vector
<
std
::
string
>&
outs_data_type
,
std
::
string
func_name
,
std
::
string
cuda_kernel_str
,
CPUKernelFunc
cpu_kernel_func
)
{
// Compile the device code
...
...
@@ -139,8 +144,9 @@ void TestMain(const std::vector<std::string>& input_names,
// Create a ProgramDesc that has a fusion_group_op.
framework
::
ProgramDesc
program
;
framework
::
OpDesc
*
op_desc
=
CreateFusionGroupOp
(
&
program
,
input_names
,
input_shapes
,
output_names
,
type
,
func_name
);
framework
::
OpDesc
*
op_desc
=
CreateFusionGroupOp
(
&
program
,
input_names
,
input_shapes
,
output_names
,
type
,
inputs_data_type
,
outs_data_type
,
func_name
);
auto
fusion_group_op
=
framework
::
OpRegistry
::
CreateOp
(
*
op_desc
);
framework
::
Scope
scope
;
...
...
@@ -210,8 +216,11 @@ void elementwise_cuda_kernel_0(size_t n, float *x, float* y, float* z) {
}
};
TestMain
(
input_names
,
input_shapes
,
output_names
,
0
,
"elementwise_cuda_kernel_0"
,
kernel
,
elementwise_cpu_kernel_0
);
std
::
vector
<
std
::
string
>
inputs_data_type
(
input_names
.
size
(),
"float"
);
std
::
vector
<
std
::
string
>
outs_data_type
(
output_names
.
size
(),
"float"
);
TestMain
(
input_names
,
input_shapes
,
output_names
,
0
,
inputs_data_type
,
outs_data_type
,
"elementwise_cuda_kernel_0"
,
kernel
,
elementwise_cpu_kernel_0
);
}
}
// namespace operators
...
...
python/paddle/fluid/tests/unittests/ir/pass_test.py
浏览文件 @
f0d193a2
...
...
@@ -142,8 +142,8 @@ class PassTest(unittest.TestCase):
self
.
assertTrue
(
np
.
allclose
(
outs_opt
[
i
],
outs
[
i
],
atol
=
atol
),
"Output < {} > has diff at {}
"
.
format
(
self
.
fetch_list
[
i
].
name
,
str
(
place
)
))
"Output < {} > has diff at {}
, expected {} but got {}"
.
format
(
self
.
fetch_list
[
i
].
name
,
str
(
place
),
outs_opt
[
i
],
outs
[
i
]
))
def
_check_fused_ops
(
self
,
program
):
'''
...
...
python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
浏览文件 @
f0d193a2
...
...
@@ -125,17 +125,15 @@ class FusionGroupPassTestFP16(FusionGroupPassTest):
fluid
.
data
(
name
=
"data2"
,
shape
=
[
128
,
128
],
dtype
=
dtype
))
# subgraph with only 1 op node
tmp_0
=
self
.
feed_vars
[
0
]
*
self
.
feed_vars
[
1
]
tmp_1
=
layers
.
mul
(
tmp_0
,
self
.
feed_vars
[
2
])
tmp_2
=
layers
.
cast
(
tmp_0
,
dtype
=
"float16"
)
tmp_3
=
layers
.
cast
(
tmp_1
,
dtype
=
"float16"
)
# subgraph with 2 op nodes
tmp_2
=
layers
.
cast
(
tmp_0
,
dtype
=
"float16"
)
tmp_4
=
layers
.
relu
(
tmp_2
+
tmp_3
)
tmp_5
=
layers
.
cast
(
tmp_4
,
dtype
=
dtype
)
self
.
fetch_list
=
[
tmp_5
]
self
.
num_fused_ops
=
1
self
.
fetch_list
=
[
tmp_
0
,
tmp_1
,
tmp_2
,
tmp_3
,
tmp_4
,
tmp_
5
]
self
.
num_fused_ops
=
2
class
FusionGroupPassSumTest
(
FusionGroupPassTest
):
...
...
@@ -147,9 +145,28 @@ class FusionGroupPassSumTest(FusionGroupPassTest):
tmp_1
=
layers
.
sum
([
tmp_0
,
self
.
feed_vars
[
2
],
self
.
feed_vars
[
3
]])
tmp_2
=
layers
.
sum
([
tmp_1
,
self
.
feed_vars
[
4
]])
self
.
fetch_list
=
[
tmp_0
,
tmp_1
]
self
.
fetch_list
=
[
tmp_0
,
tmp_1
,
tmp_2
]
self
.
num_fused_ops
=
1
class
FusionGroupPassCastTest
(
FusionGroupPassTest
):
def
build_program
(
self
,
dtype
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
self
.
feed_vars
=
self
.
_prepare_feed_vars
([
2
,
2
],
dtype
,
2
)
tmp_0
=
layers
.
elementwise_add
(
self
.
feed_vars
[
0
],
self
.
feed_vars
[
1
])
tmp_1
=
layers
.
cast
(
tmp_0
,
dtype
=
"double"
)
tmp_2
=
layers
.
cast
(
tmp_1
,
dtype
=
"float32"
)
self
.
fetch_list
=
[
tmp_0
,
tmp_1
,
tmp_2
]
self
.
num_fused_ops
=
1
def
setUp
(
self
):
self
.
build_program
(
"float64"
)
self
.
feeds
=
self
.
_feed_random_data
(
self
.
feed_vars
)
self
.
pass_names
=
"fusion_group_pass"
self
.
fused_op_type
=
"fusion_group"
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录