Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
79ed7177
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
79ed7177
编写于
5月 21, 2021
作者:
王
王明冬
提交者:
GitHub
5月 21, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add method for enhance pass,test=develop (#33004)
上级
7be6191b
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
607 addition
and
0 deletion
+607
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+2
-0
paddle/fluid/framework/ir/op_compat_sensible_pass.cc
paddle/fluid/framework/ir/op_compat_sensible_pass.cc
+178
-0
paddle/fluid/framework/ir/op_compat_sensible_pass.h
paddle/fluid/framework/ir/op_compat_sensible_pass.h
+294
-0
paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc
paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc
+133
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
79ed7177
...
...
@@ -50,6 +50,7 @@ if (WITH_TESTING)
endif
(
WITH_TESTING
)
cc_library
(
graph_pattern_detector SRCS graph_pattern_detector.cc DEPS
${
GRAPH_PATTERN_DETECTOR_DEPS
}
)
cc_library
(
op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector
)
cc_library
(
subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor
)
cc_library
(
fuse_pass_base SRCS fuse_pass_base.cc DEPS pass
)
cc_library
(
placement_pass_base SRCS placement_pass_base.cc DEPS pass
)
...
...
@@ -139,6 +140,7 @@ cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test
(
graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry
)
cc_test
(
graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass
)
cc_test
(
test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector
)
cc_test
(
test_op_compat_sensible_pass SRCS op_compat_sensible_pass_tester.cc DEPS op_compat_sensible_pass
)
cc_test
(
test_fc_fuse_pass_cc SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto
)
cc_test
(
test_fc_lstm_fuse_pass_cc SRCS fc_lstm_fuse_pass_tester.cc DEPS fc_lstm_fuse_pass framework_proto
)
cc_test
(
test_fc_gru_fuse_pass_cc SRCS fc_gru_fuse_pass_tester.cc DEPS fc_gru_fuse_pass framework_proto
)
...
...
paddle/fluid/framework/ir/op_compat_sensible_pass.cc
0 → 100644
浏览文件 @
79ed7177
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
AttrCompat
&
AttrCompat
::
IsStringIn
(
const
std
::
set
<
std
::
string
>&
candidates
)
{
conditions_
.
emplace_back
([
candidates
](
const
Attribute
&
attr
)
->
bool
{
std
::
string
value
=
BOOST_GET_CONST
(
std
::
string
,
attr
);
for
(
auto
&
str
:
candidates
)
{
if
(
str
==
value
)
{
return
true
;
}
}
return
false
;
});
return
*
this
;
}
AttrCompat
&
AttrCompat
::
IsStringMatch
(
const
std
::
function
<
bool
(
const
std
::
string
&
)
>&
func
)
{
conditions_
.
emplace_back
([
func
](
const
Attribute
&
attr
)
->
bool
{
std
::
string
value
=
BOOST_GET_CONST
(
std
::
string
,
attr
);
return
func
(
value
);
});
return
*
this
;
}
AttrCompat
&
AttrCompat
::
IsIntIn
(
const
std
::
set
<
int
>&
candidates
)
{
conditions_
.
emplace_back
([
candidates
](
const
Attribute
&
attr
)
->
bool
{
int
value
=
BOOST_GET_CONST
(
int
,
attr
);
return
candidates
.
find
(
value
)
!=
candidates
.
end
();
});
return
*
this
;
}
//! Todo: append the definition.
AttrCompat
&
AttrCompat
::
IsLeftDefault
()
{
return
*
this
;
}
bool
AttrCompat
::
operator
()(
const
OpDesc
&
op_desc
)
{
if
(
!
op_desc
.
HasAttr
(
attr_name_
))
{
return
false
;
}
const
Attribute
attr
=
op_desc
.
GetAttr
(
attr_name_
);
for
(
auto
&
func
:
conditions_
)
{
if
(
!
func
(
attr
))
{
return
false
;
}
}
return
true
;
}
AttrCompat
&
AttrCompat
::
IsBoolEQ
(
bool
v
)
{
conditions_
.
emplace_back
([
v
](
const
Attribute
&
attr
)
->
bool
{
bool
value
=
BOOST_GET_CONST
(
bool
,
attr
);
return
value
==
v
;
});
return
*
this
;
}
InputOrOutputCompat
&
InputOrOutputCompat
::
IsTensor
()
{
conditions_
.
emplace_back
([](
const
std
::
vector
<
std
::
string
>&
input
)
->
bool
{
return
input
.
size
()
==
1u
;
});
return
*
this
;
}
InputOrOutputCompat
&
InputOrOutputCompat
::
IsOptional
()
{
optional_
=
true
;
return
*
this
;
}
bool
InputOrOutputCompat
::
operator
()(
const
std
::
vector
<
std
::
string
>&
input
)
const
{
if
(
input
.
empty
())
return
false
;
for
(
auto
&
func
:
conditions_
)
{
if
(
!
func
(
input
))
{
return
false
;
}
}
return
true
;
}
AttrCompat
&
OpCompat
::
AddAttr
(
const
std
::
string
&
attr_name
)
{
attr_compats_
.
emplace_back
(
attr_name
,
this
);
return
attr_compats_
.
back
();
}
InputOrOutputCompat
&
OpCompat
::
AddInput
(
const
std
::
string
&
name
)
{
PADDLE_ENFORCE_EQ
(
input_compats_
.
find
(
name
),
input_compats_
.
end
(),
platform
::
errors
::
InvalidArgument
(
"The input with the same name has been added"
));
input_compats_
.
emplace
(
name
,
InputOrOutputCompat
(
name
,
this
));
return
input_compats_
.
at
(
name
);
}
InputOrOutputCompat
&
OpCompat
::
AddOutput
(
const
std
::
string
&
name
)
{
PADDLE_ENFORCE_EQ
(
output_compats_
.
find
(
name
),
output_compats_
.
end
(),
platform
::
errors
::
InvalidArgument
(
"The output with the same name has been added"
));
output_compats_
.
emplace
(
name
,
InputOrOutputCompat
(
name
,
this
));
return
output_compats_
.
at
(
name
);
}
bool
OpCompat
::
Judge
(
const
OpDesc
&
op_desc
)
{
for
(
auto
&
attr_compat
:
attr_compats_
)
{
if
(
!
attr_compat
(
op_desc
))
{
return
false
;
}
}
const
VariableNameMap
&
inputs_map
=
op_desc
.
Inputs
();
for
(
auto
&
input_desc
:
inputs_map
)
{
if
(
input_compats_
.
find
(
input_desc
.
first
)
==
input_compats_
.
end
())
{
if
(
!
input_desc
.
second
.
empty
())
{
return
false
;
}
}
}
for
(
auto
&
input_val
:
input_compats_
)
{
if
(
inputs_map
.
find
(
input_val
.
first
)
==
inputs_map
.
end
())
{
if
(
!
input_val
.
second
.
Optional
())
{
return
false
;
}
}
else
{
if
(
!
input_val
.
second
(
inputs_map
.
at
(
input_val
.
first
)))
{
return
false
;
}
}
}
const
VariableNameMap
&
outputs_map
=
op_desc
.
Outputs
();
for
(
auto
&
output_desc
:
outputs_map
)
{
if
(
output_compats_
.
find
(
output_desc
.
first
)
==
output_compats_
.
end
())
{
if
(
!
output_desc
.
second
.
empty
())
{
return
false
;
}
}
}
for
(
auto
&
output_val
:
output_compats_
)
{
if
(
outputs_map
.
find
(
output_val
.
first
)
==
outputs_map
.
end
())
{
if
(
!
output_val
.
second
.
Optional
())
{
return
false
;
}
}
else
{
if
(
!
output_val
.
second
(
outputs_map
.
at
(
output_val
.
first
)))
{
return
false
;
}
}
}
return
true
;
}
OpCompat
&
OpCompatSensiblePass
::
AddOpCompat
(
OpCompat
&&
op_compat
)
{
std
::
string
name
=
op_compat
.
Name
();
op_compat_judgers_
[
name
].
reset
(
new
OpCompat
(
std
::
move
(
op_compat
)));
return
*
(
op_compat_judgers_
[
name
]);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/op_compat_sensible_pass.h
0 → 100644
浏览文件 @
79ed7177
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
OpCompat
;
class
AttrCompat
{
public:
AttrCompat
(
const
std
::
string
&
attr_name
,
OpCompat
*
op_compat
)
:
attr_name_
(
attr_name
),
op_compat_
(
op_compat
)
{}
// @{ String-related methods
//! Assert the attribute is an string in the `candidates` domain.
AttrCompat
&
IsStringIn
(
const
std
::
set
<
std
::
string
>&
candidates
);
//! Assert the attribute is a string and match a custom judging function.
AttrCompat
&
IsStringMatch
(
const
std
::
function
<
bool
(
const
std
::
string
&
)
>&
func
);
// @}
//! Assert the attribute is an integer in the `candidates` domain.
AttrCompat
&
IsIntIn
(
const
std
::
set
<
int
>&
candidates
);
// @{ Number-releated methods
//! Assert the attribute is a number and > `v`.
template
<
typename
T
>
AttrCompat
&
IsNumGT
(
T
v
);
//! Assert the attribute is a number and >= `v`.
template
<
typename
T
>
AttrCompat
&
IsNumGE
(
T
v
);
//! Assert the attribute is a number and < `v`.
template
<
typename
T
>
AttrCompat
&
IsNumLT
(
T
v
);
//! Assert the attribute is a number and <= `v`.
template
<
typename
T
>
AttrCompat
&
IsNumLE
(
T
v
);
//! Assert the attribute is a number and == `v`.
template
<
typename
T
>
AttrCompat
&
IsNumEQ
(
T
v
);
//! Assert the attribute is a number and matches a customized judging
//! function.
template
<
typename
T
>
AttrCompat
&
IsNumMatch
(
bool
(
*
func
)(
T
));
// @}
//! Assert the attribute is a boolean value equals `v`.
AttrCompat
&
IsBoolEQ
(
bool
v
);
//! Tell whether this attribute is left as default value.
AttrCompat
&
IsLeftDefault
();
//! Jump back to retrieve OpCompat instance.
OpCompat
&
End
()
{
return
*
op_compat_
;
}
bool
operator
()(
const
OpDesc
&
op_desc
);
private:
std
::
string
attr_name_
;
OpCompat
*
op_compat_
;
std
::
vector
<
std
::
function
<
bool
(
const
Attribute
&
)
>>
conditions_
;
};
class
InputOrOutputCompat
{
public:
InputOrOutputCompat
(
const
std
::
string
&
name
,
OpCompat
*
op_compat
)
:
optional_
(
false
),
name_
(
name
),
op_compat_
(
op_compat
)
{}
InputOrOutputCompat
&
IsTensor
();
InputOrOutputCompat
&
IsOptional
();
bool
Optional
()
const
{
return
optional_
;
}
bool
operator
()(
const
std
::
vector
<
std
::
string
>&
input
)
const
;
//! Jump back to retrieve OpCompat instance.
OpCompat
&
End
()
{
return
*
op_compat_
;
}
private:
bool
optional_
;
std
::
string
name_
;
OpCompat
*
op_compat_
;
std
::
vector
<
std
::
function
<
bool
(
const
std
::
vector
<
std
::
string
>&
)
>>
conditions_
;
};
/**
* OpCompat is a helper class to help define the compatible Op definition.
*
* Usage:
* OpCompat compat("FC");
* compat.AddAttr("in_num_col_dims").IsNumLE(1).End()
* .AddAttr("activation_type").IsStringIn({"tanh", "sigmoid"}).End()
* .AddInput("Input").IsTensor().End()
* .AddInput("W").IsTensor().End()
* .AddInput("Bias").IsTensor().IsOptional().End()
* .AddOutput("Out").IsTensor().End()
*
* All the inference-aware Op defition is as above, all the other attributes not
* contained in the definition should be set default value or it would be judged
* incompatible.
*/
class
OpCompat
{
public:
explicit
OpCompat
(
const
std
::
string
&
op_name
)
:
op_name_
(
op_name
)
{}
explicit
OpCompat
(
std
::
string
&&
op_name
)
:
op_name_
(
std
::
move
(
op_name
))
{}
explicit
OpCompat
(
const
OpCompat
&
)
=
default
;
explicit
OpCompat
(
OpCompat
&&
)
=
default
;
AttrCompat
&
AddAttr
(
const
std
::
string
&
attr_name
);
InputOrOutputCompat
&
AddInput
(
const
std
::
string
&
name
);
InputOrOutputCompat
&
AddOutput
(
const
std
::
string
&
name
);
//! Judge whether an OpDesc match the defined Op compatibility.
bool
Judge
(
const
OpDesc
&
op_desc
);
const
std
::
string
&
Name
()
const
{
return
op_name_
;
}
private:
std
::
string
op_name_
;
std
::
vector
<
AttrCompat
>
attr_compats_
;
std
::
unordered_map
<
std
::
string
,
InputOrOutputCompat
>
input_compats_
;
std
::
unordered_map
<
std
::
string
,
InputOrOutputCompat
>
output_compats_
;
};
/**
* OpCompatSensiblePass is a base class for all the passes thouse is sensitive
* to Op update.
* There are two methods to help tell the compability of an Op
* bool IsCompat(const GraphPatternDetector::subgraph_t& subgraph, Graph* g);
* bool IsCompat(const OpDesc& op_desc);
*
* One can register the related Op compabilities using
* void AddOpCompat(OpCompat&& judger);
*
* Most of the Passes are used for fusing ops, so we define a method for such
* scenerios.
* void AccessSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
Graph* g);
* It will check the Op compatibility automatically.
* For other scenirios, one should call `IsCompat` by himself.
*
* A FC fuse pass example:
* class FcFusePass : public OpCompatSensiblePass {
* public:
* FcFusePass() {
* // define Mul op compatiblity.
* AddOpCompat(OpCompat("Mul"))
* .AddInput("Input").IsTensor().End()
* .AddAttr("in_num_col_dims").IsNumGE(1);
* AddOpCompat(OpCompat("Add")). ...;
* // There are multiple activation implemention.
* AddOpCompat(OpCompat("Tanh")). ...;
* AddOpCompat(OpCompat("Sigmoid")). ...;
* }
*
* // override the subgraph access method
* virtual bool AccessSubgraphImpl(
* const GraphPatternDetector::subgraph_t& subgraph,
* Graph* g) override { ... }
*
* // Call the AccessSubgraph method in main procedure of this Pass.
* };
*/
class
OpCompatSensiblePass
:
public
Pass
{
public:
//! Access the subgraph and pattern.
void
AccessSubgraph
(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
if
(
IsCompat
(
subgraph
,
g
))
{
AccessSubgraphImpl
(
subgraph
,
g
);
}
}
protected:
/**
* Developer should push the compatibility `teller` for each kind of Op in the
* subgraph.
* NOTE One should add all the related op compatiblity in the construct so
* that all the following methods are valid.
*/
OpCompat
&
AddOpCompat
(
OpCompat
&&
op_compat
);
//! Modify the subgraph.
virtual
bool
AccessSubgraphImpl
(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
const
{
return
true
;
}
//! Tell the Op compability of a subgraph.
bool
IsCompat
(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
const
{
CHECK
(
!
op_compat_judgers_
.
empty
())
<<
"At least one OpCompat instance should be added in the "
"OpCompatSensiblePass."
;
// Check the all the ops in the subgraph are contained in the
// op_compat.
for
(
auto
&
node_pair
:
subgraph
)
{
if
(
!
node_pair
.
first
->
IsOp
())
continue
;
auto
op_type
=
node_pair
.
second
->
Op
()
->
Type
();
if
(
!
op_compat_judgers_
.
count
(
op_type
))
{
return
false
;
}
auto
&
judger
=
*
op_compat_judgers_
.
at
(
op_type
);
if
(
!
judger
.
Judge
(
*
(
node_pair
.
second
->
Op
())))
{
return
false
;
}
}
return
true
;
}
//! Tell the op compatibility of a single Op.
bool
IsCompat
(
const
OpDesc
&
op_desc
)
const
{
if
(
!
op_compat_judgers_
.
count
(
op_desc
.
Type
()))
return
false
;
return
op_compat_judgers_
.
at
(
op_desc
.
Type
())
->
Judge
(
op_desc
);
}
private:
std
::
map
<
std
::
string
,
std
::
unique_ptr
<
OpCompat
>>
op_compat_judgers_
;
};
template
<
typename
T
>
AttrCompat
&
AttrCompat
::
IsNumGT
(
T
v
)
{
conditions_
.
emplace_back
([
v
](
const
Attribute
&
attr
)
->
bool
{
T
value
=
BOOST_GET_CONST
(
T
,
attr
);
return
value
>
v
;
});
return
*
this
;
}
template
<
typename
T
>
AttrCompat
&
AttrCompat
::
IsNumGE
(
T
v
)
{
conditions_
.
emplace_back
([
v
](
const
Attribute
&
attr
)
->
bool
{
T
value
=
BOOST_GET_CONST
(
T
,
attr
);
return
value
>=
v
;
});
return
*
this
;
}
template
<
typename
T
>
AttrCompat
&
AttrCompat
::
IsNumLT
(
T
v
)
{
conditions_
.
emplace_back
([
v
](
const
Attribute
&
attr
)
->
bool
{
T
value
=
BOOST_GET_CONST
(
T
,
attr
);
return
value
<
v
;
});
return
*
this
;
}
template
<
typename
T
>
AttrCompat
&
AttrCompat
::
IsNumLE
(
T
v
)
{
conditions_
.
emplace_back
([
v
](
const
Attribute
&
attr
)
->
bool
{
T
value
=
BOOST_GET_CONST
(
T
,
attr
);
return
value
<=
v
;
});
return
*
this
;
}
template
<
typename
T
>
AttrCompat
&
AttrCompat
::
IsNumEQ
(
T
v
)
{
conditions_
.
emplace_back
([
v
](
const
Attribute
&
attr
)
->
bool
{
T
value
=
BOOST_GET_CONST
(
T
,
attr
);
return
value
==
v
;
});
return
*
this
;
}
template
<
typename
T
>
AttrCompat
&
AttrCompat
::
IsNumMatch
(
bool
(
*
func
)(
T
))
{
conditions_
.
emplace_back
([
func
](
const
Attribute
&
attr
)
->
bool
{
T
value
=
BOOST_GET_CONST
(
T
,
attr
);
return
func
(
value
);
});
return
*
this
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc
0 → 100644
浏览文件 @
79ed7177
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/program_desc.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
TEST
(
OpCompatSensiblePass
,
compatOp
)
{
auto
lambda
=
[](
const
std
::
string
&
str
)
{
return
str
==
"tanh"
;
};
OpCompat
compat
(
"FC"
);
compat
.
AddAttr
(
"in_num_col_dims"
)
.
IsIntIn
({
1
,
2
})
.
IsNumLE
(
1
)
.
IsLeftDefault
()
.
End
()
.
AddAttr
(
"activation_type"
)
.
IsStringIn
({
"tanh"
,
"sigmoid"
})
.
IsStringMatch
(
lambda
)
.
End
()
.
AddAttr
(
"test_attr"
)
.
IsBoolEQ
(
true
)
.
End
()
.
AddInput
(
"Input"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"W"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddInput
(
"Test"
)
.
IsOptional
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
();
OpDesc
fc_op
;
std
::
unordered_map
<
std
::
string
,
Attribute
>
attr_map
;
attr_map
[
"in_num_col_dims"
]
=
1
;
attr_map
[
"activation_type"
]
=
std
::
string
(
"tanh"
);
attr_map
[
"test_attr"
]
=
true
;
fc_op
.
SetAttrMap
(
attr_map
);
fc_op
.
SetInput
(
"Input"
,
std
::
vector
<
std
::
string
>
{
"test_input"
});
fc_op
.
SetInput
(
"W"
,
std
::
vector
<
std
::
string
>
{
"test_input_0"
});
fc_op
.
SetInput
(
"Bias"
,
std
::
vector
<
std
::
string
>
{
"test_input_1"
});
fc_op
.
SetOutput
(
"Out"
,
std
::
vector
<
std
::
string
>
{
"test_output"
});
EXPECT_STREQ
(
compat
.
Name
().
c_str
(),
"FC"
);
EXPECT_TRUE
(
compat
.
Judge
(
fc_op
));
}
class
OpCompatSensiblePassTest
:
public
OpCompatSensiblePass
{
public:
OpCompatSensiblePassTest
();
bool
TestIsCompat
(
const
OpDesc
&
op_desc
)
{
return
IsCompat
(
op_desc
);
}
};
OpCompatSensiblePassTest
::
OpCompatSensiblePassTest
()
{
AddOpCompat
(
OpCompat
(
"FC"
))
.
AddAttr
(
"in_num_col_dims"
)
.
IsNumLE
(
1
)
.
End
()
.
AddAttr
(
"activation_type"
)
.
IsStringIn
({
"tanh"
,
"sigmoid"
})
.
End
()
.
AddInput
(
"Input"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"W"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
();
}
TEST
(
OpCompatSensiblePass
,
IsCompat
)
{
OpCompatSensiblePassTest
test
;
OpDesc
fc_op
;
fc_op
.
SetType
(
"FC"
);
std
::
unordered_map
<
std
::
string
,
Attribute
>
attr_map
;
attr_map
[
"in_num_col_dims"
]
=
1
;
attr_map
[
"activation_type"
]
=
std
::
string
(
"tanh"
);
fc_op
.
SetAttrMap
(
attr_map
);
fc_op
.
SetInput
(
"Input"
,
std
::
vector
<
std
::
string
>
{
"test_input"
});
fc_op
.
SetInput
(
"W"
,
std
::
vector
<
std
::
string
>
{
"test_input_0"
});
fc_op
.
SetInput
(
"Bias"
,
std
::
vector
<
std
::
string
>
{
"test_input_1"
});
fc_op
.
SetOutput
(
"Out"
,
std
::
vector
<
std
::
string
>
{
"test_output"
});
EXPECT_TRUE
(
test
.
TestIsCompat
(
fc_op
));
ProgramDesc
prog
;
std
::
unique_ptr
<
Graph
>
g
(
new
Graph
(
prog
));
Node
*
o1
=
g
->
CreateOpNode
(
&
fc_op
);
GraphPatternDetector
detector
;
PDNode
*
op2
=
detector
.
mutable_pattern
()
->
NewNode
([](
Node
*
x
)
{
return
true
;
});
GraphPatternDetector
::
subgraph_t
subgraph
;
subgraph
[
op2
]
=
o1
;
test
.
AccessSubgraph
(
subgraph
,
g
.
get
());
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录