Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
dc72ffa5
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
dc72ffa5
编写于
5月 25, 2021
作者:
王
王明冬
提交者:
GitHub
5月 25, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add the IsLeftDefault definition for pass enhance,test=develop (#33081)
上级
88dfb30f
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
138 addition
and
44 deletion
+138
-44
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+1
-1
paddle/fluid/framework/ir/fuse_pass_base.h
paddle/fluid/framework/ir/fuse_pass_base.h
+2
-2
paddle/fluid/framework/ir/op_compat_sensible_pass.cc
paddle/fluid/framework/ir/op_compat_sensible_pass.cc
+59
-6
paddle/fluid/framework/ir/op_compat_sensible_pass.h
paddle/fluid/framework/ir/op_compat_sensible_pass.h
+6
-18
paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc
paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc
+70
-17
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
dc72ffa5
...
@@ -52,7 +52,7 @@ cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS ${GRAPH_PA
...
@@ -52,7 +52,7 @@ cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS ${GRAPH_PA
cc_library
(
op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector
)
cc_library
(
op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector
)
cc_library
(
subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor
)
cc_library
(
subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor
)
cc_library
(
fuse_pass_base SRCS fuse_pass_base.cc DEPS pass
)
cc_library
(
fuse_pass_base SRCS fuse_pass_base.cc DEPS
op_compat_sensible_
pass
)
cc_library
(
placement_pass_base SRCS placement_pass_base.cc DEPS pass
)
cc_library
(
placement_pass_base SRCS placement_pass_base.cc DEPS pass
)
cc_library
(
coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS graph graph_helper
)
cc_library
(
coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS graph graph_helper
)
...
...
paddle/fluid/framework/ir/fuse_pass_base.h
浏览文件 @
dc72ffa5
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
#include <string>
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/
op_compat_sensible_
pass.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -46,7 +46,7 @@ enum FuseOptions {
...
@@ -46,7 +46,7 @@ enum FuseOptions {
FUSE_MKLDNN
// fusing will be done with MKL-DNN
FUSE_MKLDNN
// fusing will be done with MKL-DNN
};
};
class
FusePassBase
:
public
Pass
{
class
FusePassBase
:
public
OpCompatSensible
Pass
{
public:
public:
void
Init
(
const
std
::
string
&
repr
,
Graph
*
graph
)
const
;
void
Init
(
const
std
::
string
&
repr
,
Graph
*
graph
)
const
;
Scope
*
param_scope
()
const
;
Scope
*
param_scope
()
const
;
...
...
paddle/fluid/framework/ir/op_compat_sensible_pass.cc
浏览文件 @
dc72ffa5
...
@@ -15,7 +15,7 @@ limitations under the License. */
...
@@ -15,7 +15,7 @@ limitations under the License. */
#include <memory>
#include <memory>
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
#include "paddle/fluid/framework/op_info.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
...
@@ -51,11 +51,33 @@ AttrCompat& AttrCompat::IsIntIn(const std::set<int>& candidates) {
...
@@ -51,11 +51,33 @@ AttrCompat& AttrCompat::IsIntIn(const std::set<int>& candidates) {
}
}
//! Todo: append the definition.
//! Todo: append the definition.
AttrCompat
&
AttrCompat
::
IsLeftDefault
()
{
return
*
this
;
}
AttrCompat
&
AttrCompat
::
IsLeftDefault
()
{
const
std
::
string
&
op_name
=
op_compat_
->
Name
();
if
(
!
OpInfoMap
::
Instance
().
Has
(
op_name
))
{
VLOG
(
3
)
<<
"Op ("
<<
op_name
<<
") is not registered!"
;
conditions_
.
emplace_back
([](
const
Attribute
&
attr
)
{
return
false
;
});
return
*
this
;
}
const
OpInfo
&
op_info
=
OpInfoMap
::
Instance
().
Get
(
op_name
);
const
AttributeMap
attrs
=
op_info
.
Checker
()
->
GetAttrsDefaultValuesMap
();
if
(
attrs
.
find
(
attr_name_
)
==
attrs
.
end
())
{
VLOG
(
3
)
<<
"Op ("
<<
op_name
<<
") has no default attr:"
<<
attr_name_
;
conditions_
.
emplace_back
([](
const
Attribute
&
attr
)
{
return
false
;
});
}
else
{
Attribute
default_attr
=
attrs
.
at
(
attr_name_
);
conditions_
.
emplace_back
([
default_attr
](
const
Attribute
&
attr
)
->
bool
{
return
attr
==
default_attr
;
});
}
return
*
this
;
}
bool
AttrCompat
::
operator
()(
const
OpDesc
&
op_desc
)
{
bool
AttrCompat
::
operator
()(
const
OpDesc
&
op_desc
)
{
if
(
conditions_
.
empty
())
{
return
true
;
}
if
(
!
op_desc
.
HasAttr
(
attr_name_
))
{
if
(
!
op_desc
.
HasAttr
(
attr_name_
))
{
return
false
;
return
optional_
;
}
}
const
Attribute
attr
=
op_desc
.
GetAttr
(
attr_name_
);
const
Attribute
attr
=
op_desc
.
GetAttr
(
attr_name_
);
for
(
auto
&
func
:
conditions_
)
{
for
(
auto
&
func
:
conditions_
)
{
...
@@ -65,6 +87,10 @@ bool AttrCompat::operator()(const OpDesc& op_desc) {
...
@@ -65,6 +87,10 @@ bool AttrCompat::operator()(const OpDesc& op_desc) {
}
}
return
true
;
return
true
;
}
}
AttrCompat
&
AttrCompat
::
IsOptional
()
{
optional_
=
true
;
return
*
this
;
}
AttrCompat
&
AttrCompat
::
IsBoolEQ
(
bool
v
)
{
AttrCompat
&
AttrCompat
::
IsBoolEQ
(
bool
v
)
{
conditions_
.
emplace_back
([
v
](
const
Attribute
&
attr
)
->
bool
{
conditions_
.
emplace_back
([
v
](
const
Attribute
&
attr
)
->
bool
{
...
@@ -98,8 +124,12 @@ bool InputOrOutputCompat::operator()(
...
@@ -98,8 +124,12 @@ bool InputOrOutputCompat::operator()(
}
}
AttrCompat
&
OpCompat
::
AddAttr
(
const
std
::
string
&
attr_name
)
{
AttrCompat
&
OpCompat
::
AddAttr
(
const
std
::
string
&
attr_name
)
{
attr_compats_
.
emplace_back
(
attr_name
,
this
);
PADDLE_ENFORCE_EQ
(
return
attr_compats_
.
back
();
attr_compats_
.
find
(
attr_name
),
attr_compats_
.
end
(),
platform
::
errors
::
InvalidArgument
(
"The attrubute compat with the same name has been added"
));
attr_compats_
.
emplace
(
attr_name
,
AttrCompat
(
attr_name
,
this
));
return
attr_compats_
.
at
(
attr_name
);
}
}
InputOrOutputCompat
&
OpCompat
::
AddInput
(
const
std
::
string
&
name
)
{
InputOrOutputCompat
&
OpCompat
::
AddInput
(
const
std
::
string
&
name
)
{
...
@@ -119,8 +149,19 @@ InputOrOutputCompat& OpCompat::AddOutput(const std::string& name) {
...
@@ -119,8 +149,19 @@ InputOrOutputCompat& OpCompat::AddOutput(const std::string& name) {
}
}
bool
OpCompat
::
Judge
(
const
OpDesc
&
op_desc
)
{
bool
OpCompat
::
Judge
(
const
OpDesc
&
op_desc
)
{
for
(
auto
&
attr_map
:
op_desc
.
GetAttrMap
())
{
if
(
attr_compats_
.
find
(
attr_map
.
first
)
==
attr_compats_
.
end
())
{
if
(
!
AttrCompat
(
attr_map
.
first
,
this
).
IsLeftDefault
()(
op_desc
))
{
VLOG
(
3
)
<<
"The Attr("
<<
attr_map
.
first
<<
") of Op ("
<<
op_name_
<<
") not reigistered in OpCompat, not equal to default value!"
;
return
false
;
}
}
}
for
(
auto
&
attr_compat
:
attr_compats_
)
{
for
(
auto
&
attr_compat
:
attr_compats_
)
{
if
(
!
attr_compat
(
op_desc
))
{
if
(
!
attr_compat
.
second
(
op_desc
))
{
VLOG
(
3
)
<<
" Check the Attr("
<<
attr_compat
.
first
<<
") of Op("
<<
op_name_
<<
") failed!"
;
return
false
;
return
false
;
}
}
}
}
...
@@ -129,6 +170,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
...
@@ -129,6 +170,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
for
(
auto
&
input_desc
:
inputs_map
)
{
for
(
auto
&
input_desc
:
inputs_map
)
{
if
(
input_compats_
.
find
(
input_desc
.
first
)
==
input_compats_
.
end
())
{
if
(
input_compats_
.
find
(
input_desc
.
first
)
==
input_compats_
.
end
())
{
if
(
!
input_desc
.
second
.
empty
())
{
if
(
!
input_desc
.
second
.
empty
())
{
VLOG
(
3
)
<<
"The Input ("
<<
input_desc
.
first
<<
") of Operator ("
<<
op_name_
<<
") not reigistered in OpCompat!"
;
return
false
;
return
false
;
}
}
}
}
...
@@ -136,10 +179,14 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
...
@@ -136,10 +179,14 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
for
(
auto
&
input_val
:
input_compats_
)
{
for
(
auto
&
input_val
:
input_compats_
)
{
if
(
inputs_map
.
find
(
input_val
.
first
)
==
inputs_map
.
end
())
{
if
(
inputs_map
.
find
(
input_val
.
first
)
==
inputs_map
.
end
())
{
if
(
!
input_val
.
second
.
Optional
())
{
if
(
!
input_val
.
second
.
Optional
())
{
VLOG
(
3
)
<<
"The No optional Input ("
<<
input_val
.
first
<<
") of Operator ("
<<
op_name_
<<
") not find in op_desc!"
;
return
false
;
return
false
;
}
}
}
else
{
}
else
{
if
(
!
input_val
.
second
(
inputs_map
.
at
(
input_val
.
first
)))
{
if
(
!
input_val
.
second
(
inputs_map
.
at
(
input_val
.
first
)))
{
VLOG
(
3
)
<<
"The Input ("
<<
input_val
.
first
<<
") of Operator ("
<<
op_name_
<<
") compat check failed!"
;
return
false
;
return
false
;
}
}
}
}
...
@@ -149,6 +196,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
...
@@ -149,6 +196,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
for
(
auto
&
output_desc
:
outputs_map
)
{
for
(
auto
&
output_desc
:
outputs_map
)
{
if
(
output_compats_
.
find
(
output_desc
.
first
)
==
output_compats_
.
end
())
{
if
(
output_compats_
.
find
(
output_desc
.
first
)
==
output_compats_
.
end
())
{
if
(
!
output_desc
.
second
.
empty
())
{
if
(
!
output_desc
.
second
.
empty
())
{
VLOG
(
3
)
<<
"The Output ("
<<
output_desc
.
first
<<
") of Operator ("
<<
op_name_
<<
") not reigistered in OpCompat!"
;
return
false
;
return
false
;
}
}
}
}
...
@@ -156,10 +205,14 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
...
@@ -156,10 +205,14 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
for
(
auto
&
output_val
:
output_compats_
)
{
for
(
auto
&
output_val
:
output_compats_
)
{
if
(
outputs_map
.
find
(
output_val
.
first
)
==
outputs_map
.
end
())
{
if
(
outputs_map
.
find
(
output_val
.
first
)
==
outputs_map
.
end
())
{
if
(
!
output_val
.
second
.
Optional
())
{
if
(
!
output_val
.
second
.
Optional
())
{
VLOG
(
3
)
<<
"The No optional Output ("
<<
output_val
.
first
<<
") of Operator ("
<<
op_name_
<<
") not find in op_desc!"
;
return
false
;
return
false
;
}
}
}
else
{
}
else
{
if
(
!
output_val
.
second
(
outputs_map
.
at
(
output_val
.
first
)))
{
if
(
!
output_val
.
second
(
outputs_map
.
at
(
output_val
.
first
)))
{
VLOG
(
3
)
<<
"The Output ("
<<
output_val
.
first
<<
") of Operator ("
<<
op_name_
<<
") compat check failed!"
;
return
false
;
return
false
;
}
}
}
}
...
...
paddle/fluid/framework/ir/op_compat_sensible_pass.h
浏览文件 @
dc72ffa5
...
@@ -29,7 +29,7 @@ class OpCompat;
...
@@ -29,7 +29,7 @@ class OpCompat;
class
AttrCompat
{
class
AttrCompat
{
public:
public:
AttrCompat
(
const
std
::
string
&
attr_name
,
OpCompat
*
op_compat
)
AttrCompat
(
const
std
::
string
&
attr_name
,
OpCompat
*
op_compat
)
:
attr_name_
(
attr_name
),
op_compat_
(
op_compat
)
{}
:
optional_
(
false
),
attr_name_
(
attr_name
),
op_compat_
(
op_compat
)
{}
// @{ String-related methods
// @{ String-related methods
//! Assert the attribute is an string in the `candidates` domain.
//! Assert the attribute is an string in the `candidates` domain.
...
@@ -70,12 +70,15 @@ class AttrCompat {
...
@@ -70,12 +70,15 @@ class AttrCompat {
//! Tell whether this attribute is left as default value.
//! Tell whether this attribute is left as default value.
AttrCompat
&
IsLeftDefault
();
AttrCompat
&
IsLeftDefault
();
AttrCompat
&
IsOptional
();
//! Jump back to retrieve OpCompat instance.
//! Jump back to retrieve OpCompat instance.
OpCompat
&
End
()
{
return
*
op_compat_
;
}
OpCompat
&
End
()
{
return
*
op_compat_
;
}
bool
operator
()(
const
OpDesc
&
op_desc
);
bool
operator
()(
const
OpDesc
&
op_desc
);
private:
private:
bool
optional_
;
std
::
string
attr_name_
;
std
::
string
attr_name_
;
OpCompat
*
op_compat_
;
OpCompat
*
op_compat_
;
std
::
vector
<
std
::
function
<
bool
(
const
Attribute
&
)
>>
conditions_
;
std
::
vector
<
std
::
function
<
bool
(
const
Attribute
&
)
>>
conditions_
;
...
@@ -134,7 +137,7 @@ class OpCompat {
...
@@ -134,7 +137,7 @@ class OpCompat {
private:
private:
std
::
string
op_name_
;
std
::
string
op_name_
;
std
::
vector
<
AttrCompat
>
attr_compats_
;
std
::
unordered_map
<
std
::
string
,
AttrCompat
>
attr_compats_
;
std
::
unordered_map
<
std
::
string
,
InputOrOutputCompat
>
input_compats_
;
std
::
unordered_map
<
std
::
string
,
InputOrOutputCompat
>
input_compats_
;
std
::
unordered_map
<
std
::
string
,
InputOrOutputCompat
>
output_compats_
;
std
::
unordered_map
<
std
::
string
,
InputOrOutputCompat
>
output_compats_
;
};
};
...
@@ -179,15 +182,6 @@ class OpCompat {
...
@@ -179,15 +182,6 @@ class OpCompat {
* };
* };
*/
*/
class
OpCompatSensiblePass
:
public
Pass
{
class
OpCompatSensiblePass
:
public
Pass
{
public:
//! Access the subgraph and pattern.
void
AccessSubgraph
(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
if
(
IsCompat
(
subgraph
,
g
))
{
AccessSubgraphImpl
(
subgraph
,
g
);
}
}
protected:
protected:
/**
/**
* Developer should push the compatibility `teller` for each kind of Op in the
* Developer should push the compatibility `teller` for each kind of Op in the
...
@@ -197,12 +191,6 @@ class OpCompatSensiblePass : public Pass {
...
@@ -197,12 +191,6 @@ class OpCompatSensiblePass : public Pass {
*/
*/
OpCompat
&
AddOpCompat
(
OpCompat
&&
op_compat
);
OpCompat
&
AddOpCompat
(
OpCompat
&&
op_compat
);
//! Modify the subgraph.
virtual
bool
AccessSubgraphImpl
(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
const
{
return
true
;
}
//! Tell the Op compability of a subgraph.
//! Tell the Op compability of a subgraph.
bool
IsCompat
(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
bool
IsCompat
(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
const
{
Graph
*
g
)
const
{
...
@@ -212,7 +200,7 @@ class OpCompatSensiblePass : public Pass {
...
@@ -212,7 +200,7 @@ class OpCompatSensiblePass : public Pass {
// Check the all the ops in the subgraph are contained in the
// Check the all the ops in the subgraph are contained in the
// op_compat.
// op_compat.
for
(
auto
&
node_pair
:
subgraph
)
{
for
(
auto
&
node_pair
:
subgraph
)
{
if
(
!
node_pair
.
first
->
IsOp
())
continue
;
if
(
!
node_pair
.
second
->
IsOp
())
continue
;
auto
op_type
=
node_pair
.
second
->
Op
()
->
Type
();
auto
op_type
=
node_pair
.
second
->
Op
()
->
Type
();
if
(
!
op_compat_judgers_
.
count
(
op_type
))
{
if
(
!
op_compat_judgers_
.
count
(
op_type
))
{
return
false
;
return
false
;
...
...
paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc
浏览文件 @
dc72ffa5
...
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
...
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -23,7 +23,7 @@ namespace ir {
...
@@ -23,7 +23,7 @@ namespace ir {
TEST
(
OpCompatSensiblePass
,
compatOp
)
{
TEST
(
OpCompatSensiblePass
,
compatOp
)
{
auto
lambda
=
[](
const
std
::
string
&
str
)
{
return
str
==
"tanh"
;
};
auto
lambda
=
[](
const
std
::
string
&
str
)
{
return
str
==
"tanh"
;
};
OpCompat
compat
(
"
FC
"
);
OpCompat
compat
(
"
fc
"
);
compat
.
AddAttr
(
"in_num_col_dims"
)
compat
.
AddAttr
(
"in_num_col_dims"
)
.
IsIntIn
({
1
,
2
})
.
IsIntIn
({
1
,
2
})
.
IsNumLE
(
1
)
.
IsNumLE
(
1
)
...
@@ -67,10 +67,75 @@ TEST(OpCompatSensiblePass, compatOp) {
...
@@ -67,10 +67,75 @@ TEST(OpCompatSensiblePass, compatOp) {
fc_op
.
SetInput
(
"Bias"
,
std
::
vector
<
std
::
string
>
{
"test_input_1"
});
fc_op
.
SetInput
(
"Bias"
,
std
::
vector
<
std
::
string
>
{
"test_input_1"
});
fc_op
.
SetOutput
(
"Out"
,
std
::
vector
<
std
::
string
>
{
"test_output"
});
fc_op
.
SetOutput
(
"Out"
,
std
::
vector
<
std
::
string
>
{
"test_output"
});
EXPECT_STREQ
(
compat
.
Name
().
c_str
(),
"FC"
);
EXPECT_STREQ
(
compat
.
Name
().
c_str
(),
"fc"
);
EXPECT_FALSE
(
compat
.
Judge
(
fc_op
));
}
TEST
(
OpCompatSensiblePass
,
compatOpAttribute
)
{
OpCompat
compat
(
"fc"
);
OpDesc
fc_op
;
std
::
unordered_map
<
std
::
string
,
Attribute
>
attr_map
;
attr_map
[
"in_num_col_dims"
]
=
1
;
fc_op
.
SetAttrMap
(
attr_map
);
OpInfo
info
;
info
.
checker_
=
new
OpAttrChecker
();
OpInfoMap
::
Instance
().
Insert
(
"fc"
,
info
);
EXPECT_FALSE
(
compat
.
Judge
(
fc_op
));
info
.
checker_
->
AddAttrChecker
<
int
>
(
"in_num_col_dims"
).
SetDefault
(
1
);
EXPECT_TRUE
(
compat
.
Judge
(
fc_op
));
delete
info
.
checker_
;
}
TEST
(
OpCompatSensiblePass
,
compatOpAttributeOptional
)
{
OpCompat
compat
(
"fc"
);
compat
.
AddAttr
(
"activation_type"
)
.
IsOptional
()
.
IsStringIn
({
"tanh"
,
"sigmoid"
});
OpDesc
fc_op
;
EXPECT_TRUE
(
compat
.
Judge
(
fc_op
));
EXPECT_TRUE
(
compat
.
Judge
(
fc_op
));
}
}
TEST
(
OpCompatSensiblePass
,
compatOpInput
)
{
OpCompat
compat
(
"fc"
);
OpDesc
fc_op
;
fc_op
.
SetInput
(
"Input"
,
std
::
vector
<
std
::
string
>
{
"test_input"
});
EXPECT_FALSE
(
compat
.
Judge
(
fc_op
));
compat
.
AddInput
(
"Input"
).
IsTensor
().
End
().
AddInput
(
"Bias"
).
IsTensor
().
End
();
EXPECT_FALSE
(
compat
.
Judge
(
fc_op
));
fc_op
.
SetInput
(
"Bias"
,
std
::
vector
<
std
::
string
>
{
"test_input"
,
""
});
EXPECT_FALSE
(
compat
.
Judge
(
fc_op
));
}
TEST
(
OpCompatSensiblePass
,
compatOutput
)
{
OpCompat
compat
(
"fc"
);
OpDesc
fc_op
;
fc_op
.
SetOutput
(
"Output"
,
std
::
vector
<
std
::
string
>
{
"test_output"
});
EXPECT_FALSE
(
compat
.
Judge
(
fc_op
));
compat
.
AddOutput
(
"Output"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Output_2"
)
.
IsTensor
()
.
End
();
EXPECT_FALSE
(
compat
.
Judge
(
fc_op
));
fc_op
.
SetOutput
(
"Output_2"
,
std
::
vector
<
std
::
string
>
{
"test_output"
,
""
});
EXPECT_FALSE
(
compat
.
Judge
(
fc_op
));
}
class
OpCompatSensiblePassTest
:
public
OpCompatSensiblePass
{
class
OpCompatSensiblePassTest
:
public
OpCompatSensiblePass
{
public:
public:
OpCompatSensiblePassTest
();
OpCompatSensiblePassTest
();
...
@@ -78,7 +143,7 @@ class OpCompatSensiblePassTest : public OpCompatSensiblePass {
...
@@ -78,7 +143,7 @@ class OpCompatSensiblePassTest : public OpCompatSensiblePass {
};
};
OpCompatSensiblePassTest
::
OpCompatSensiblePassTest
()
{
OpCompatSensiblePassTest
::
OpCompatSensiblePassTest
()
{
AddOpCompat
(
OpCompat
(
"
FC
"
))
AddOpCompat
(
OpCompat
(
"
fc
"
))
.
AddAttr
(
"in_num_col_dims"
)
.
AddAttr
(
"in_num_col_dims"
)
.
IsNumLE
(
1
)
.
IsNumLE
(
1
)
.
End
()
.
End
()
...
@@ -102,7 +167,7 @@ OpCompatSensiblePassTest::OpCompatSensiblePassTest() {
...
@@ -102,7 +167,7 @@ OpCompatSensiblePassTest::OpCompatSensiblePassTest() {
TEST
(
OpCompatSensiblePass
,
IsCompat
)
{
TEST
(
OpCompatSensiblePass
,
IsCompat
)
{
OpCompatSensiblePassTest
test
;
OpCompatSensiblePassTest
test
;
OpDesc
fc_op
;
OpDesc
fc_op
;
fc_op
.
SetType
(
"
FC
"
);
fc_op
.
SetType
(
"
fc
"
);
std
::
unordered_map
<
std
::
string
,
Attribute
>
attr_map
;
std
::
unordered_map
<
std
::
string
,
Attribute
>
attr_map
;
attr_map
[
"in_num_col_dims"
]
=
1
;
attr_map
[
"in_num_col_dims"
]
=
1
;
attr_map
[
"activation_type"
]
=
std
::
string
(
"tanh"
);
attr_map
[
"activation_type"
]
=
std
::
string
(
"tanh"
);
...
@@ -114,18 +179,6 @@ TEST(OpCompatSensiblePass, IsCompat) {
...
@@ -114,18 +179,6 @@ TEST(OpCompatSensiblePass, IsCompat) {
fc_op
.
SetOutput
(
"Out"
,
std
::
vector
<
std
::
string
>
{
"test_output"
});
fc_op
.
SetOutput
(
"Out"
,
std
::
vector
<
std
::
string
>
{
"test_output"
});
EXPECT_TRUE
(
test
.
TestIsCompat
(
fc_op
));
EXPECT_TRUE
(
test
.
TestIsCompat
(
fc_op
));
ProgramDesc
prog
;
std
::
unique_ptr
<
Graph
>
g
(
new
Graph
(
prog
));
Node
*
o1
=
g
->
CreateOpNode
(
&
fc_op
);
GraphPatternDetector
detector
;
PDNode
*
op2
=
detector
.
mutable_pattern
()
->
NewNode
([](
Node
*
x
)
{
return
true
;
});
GraphPatternDetector
::
subgraph_t
subgraph
;
subgraph
[
op2
]
=
o1
;
test
.
AccessSubgraph
(
subgraph
,
g
.
get
());
}
}
}
// namespace ir
}
// namespace ir
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录