Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
98e85f37
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
98e85f37
编写于
1月 11, 2019
作者:
Z
Zhaolong Xing
提交者:
Yan Chunwei
1月 11, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add_transpose_flatten_concat_fuse (#15121)
上级
5d9edb41
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
311 addition
and
6 deletion
+311
-6
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+11
-0
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+63
-0
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+15
-0
paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc
.../fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc
+148
-0
paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h
...e/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h
+38
-0
paddle/fluid/inference/api/analysis_config.cc
paddle/fluid/inference/api/analysis_config.cc
+1
-0
paddle/fluid/inference/api/paddle_pass_builder.h
paddle/fluid/inference/api/paddle_pass_builder.h
+4
-0
paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
+31
-6
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
98e85f37
...
@@ -48,6 +48,17 @@ pass_library(conv_elementwise_add_act_fuse_pass inference)
...
@@ -48,6 +48,17 @@ pass_library(conv_elementwise_add_act_fuse_pass inference)
pass_library
(
conv_elementwise_add2_act_fuse_pass inference
)
pass_library
(
conv_elementwise_add2_act_fuse_pass inference
)
pass_library
(
conv_elementwise_add_fuse_pass inference
)
pass_library
(
conv_elementwise_add_fuse_pass inference
)
pass_library
(
conv_affine_channel_fuse_pass inference
)
pass_library
(
conv_affine_channel_fuse_pass inference
)
pass_library
(
transpose_flatten_concat_fuse_pass inference
)
# There may be many transpose-flatten structures in a model, and the output of
# these structures will be used as inputs to the concat Op. This pattern will
# be detected by our pass. The index here represents the number of structures in the
# pattern. We use index 3 ~ 6, because these quantities of structures are
# common in the models.
foreach
(
index RANGE 3 6
)
file
(
APPEND
${
pass_file
}
"USE_PASS(transpose_flatten
${
index
}
_concat_fuse_pass);
\n
"
)
endforeach
()
if
(
WITH_MKLDNN
)
if
(
WITH_MKLDNN
)
pass_library
(
mkldnn_placement_pass base
)
pass_library
(
mkldnn_placement_pass base
)
pass_library
(
depthwise_conv_mkldnn_pass base
)
pass_library
(
depthwise_conv_mkldnn_pass base
)
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
98e85f37
...
@@ -1306,6 +1306,69 @@ PDNode *patterns::ConvAffineChannel::operator()(
...
@@ -1306,6 +1306,69 @@ PDNode *patterns::ConvAffineChannel::operator()(
return
ac_out_var
;
return
ac_out_var
;
}
}
// a -> transpose_op(1) -> transpose_out_a -> flatten_op(1) -> flatten_out_a
// b -> transpose_op(2) -> transpose_out_b -> flatten_op(2) -> flatten_out_b
// ...
// z -> transpose_op(n) -> transpose_out_z -> flatten_op(n) -> flatten_out_z
// flatten_out_a -> concat_op flatten_out_b -> concat_op ... flatten_out_z ->
// concat_op
PDNode
*
patterns
::
TransposeFlattenConcat
::
operator
()(
std
::
vector
<
PDNode
*>
conv_in
,
int
times
)
{
// The times represents the repeat times of the
// {trans, trans_out, flatten, flatten_out}
const
int
kNumFields
=
4
;
const
int
kTransOutOffset
=
1
;
const
int
kFlattenOffset
=
2
;
const
int
kFlattenOutOffset
=
3
;
std
::
vector
<
PDNode
*>
nodes
;
for
(
int
i
=
0
;
i
<
times
;
i
++
)
{
nodes
.
push_back
(
pattern
->
NewNode
(
GetNodeName
(
"transpose"
+
std
::
to_string
(
i
)))
->
assert_is_op
(
"transpose2"
));
nodes
.
push_back
(
pattern
->
NewNode
(
GetNodeName
(
"transpose_out"
+
std
::
to_string
(
i
)))
->
assert_is_op_output
(
"transpose2"
)
->
assert_is_op_input
(
"flatten2"
,
"X"
)
->
AsIntermediate
());
nodes
.
push_back
(
pattern
->
NewNode
(
GetNodeName
(
"flatten"
+
std
::
to_string
(
i
)))
->
assert_is_op
(
"flatten2"
));
nodes
.
push_back
(
pattern
->
NewNode
(
GetNodeName
(
"flatten_out"
+
std
::
to_string
(
i
)))
->
assert_is_op_output
(
"flatten2"
)
->
assert_is_op_nth_input
(
"concat"
,
"X"
,
i
)
->
AsIntermediate
());
}
auto
concat_op
=
pattern
->
NewNode
(
GetNodeName
(
"concat"
))
->
assert_is_op
(
"concat"
)
->
assert_op_has_n_inputs
(
"concat"
,
times
);
auto
concat_out
=
pattern
->
NewNode
(
GetNodeName
(
"concat_out"
))
->
assert_is_op_output
(
"concat"
)
->
AsOutput
();
std
::
vector
<
PDNode
*>
flatten_outs
;
for
(
int
i
=
0
;
i
<
times
;
i
++
)
{
conv_in
[
i
]
->
AsInput
();
// trans
nodes
[
i
*
kNumFields
]
->
LinksFrom
({
conv_in
[
i
]});
// trans_out
nodes
[
i
*
kNumFields
+
kTransOutOffset
]
->
LinksFrom
({
nodes
[
i
*
kNumFields
]});
// flatten
nodes
[
i
*
kNumFields
+
kFlattenOffset
]
->
LinksFrom
(
{
nodes
[
i
*
kNumFields
+
kTransOutOffset
]});
// flatten_out
nodes
[
i
*
kNumFields
+
kFlattenOutOffset
]
->
LinksFrom
(
{
nodes
[
i
*
kNumFields
+
kFlattenOffset
]});
flatten_outs
.
push_back
(
nodes
[
i
*
kNumFields
+
kFlattenOutOffset
]);
}
concat_op
->
LinksFrom
(
flatten_outs
).
LinksTo
({
concat_out
});
return
concat_out
;
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
98e85f37
...
@@ -766,6 +766,21 @@ struct ConvAffineChannel : public PatternBase {
...
@@ -766,6 +766,21 @@ struct ConvAffineChannel : public PatternBase {
PATTERN_DECL_NODE
(
ac_out
);
// Out
PATTERN_DECL_NODE
(
ac_out
);
// Out
};
};
struct
TransposeFlattenConcat
:
public
PatternBase
{
TransposeFlattenConcat
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"transpose_flatten_concat"
)
{}
PDNode
*
operator
()(
std
::
vector
<
PDNode
*>
conv_inputs
,
int
times
);
std
::
string
GetNodeName
(
const
std
::
string
&
op_type
)
{
return
PDNodeName
(
name_scope_
,
repr_
,
id_
,
op_type
);
}
PDNode
*
GetPDNode
(
const
std
::
string
&
op_type
)
{
return
pattern
->
RetrieveNode
(
GetNodeName
(
op_type
));
}
};
}
// namespace patterns
}
// namespace patterns
// Link two ir::Nodes from each other.
// Link two ir::Nodes from each other.
...
...
paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc
0 → 100644
浏览文件 @
98e85f37
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
template
<
int
times
>
std
::
unique_ptr
<
ir
::
Graph
>
TransposeFlattenConcatFusePass
<
times
>::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
const
std
::
string
pattern_name
=
"transpose_flatten"
+
std
::
to_string
(
times
)
+
"_concat_fuse"
;
FusePassBase
::
Init
(
pattern_name
,
graph
.
get
());
GraphPatternDetector
gpd
;
std
::
vector
<
PDNode
*>
input_nodes
;
for
(
int
i
=
0
;
i
<
times
;
i
++
)
{
input_nodes
.
push_back
(
gpd
.
mutable_pattern
()
->
NewNode
(
"x"
+
std
::
to_string
(
i
))
->
assert_is_op_input
(
"transpose2"
,
"X"
)
->
AsInput
());
}
patterns
::
TransposeFlattenConcat
pattern
(
gpd
.
mutable_pattern
(),
pattern_name
);
pattern
(
input_nodes
,
times
);
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
const
int
kNumFields
=
5
;
const
int
kTransOffset
=
1
;
const
int
kTransOutOffset
=
2
;
const
int
kFlattenOffset
=
3
;
const
int
kFlattenOutOffset
=
4
;
std
::
vector
<
Node
*>
nodes
;
for
(
int
i
=
0
;
i
<
times
;
i
++
)
{
PADDLE_ENFORCE
(
subgraph
.
at
(
pattern
.
GetPDNode
(
"transpose"
+
std
::
to_string
(
i
))));
PADDLE_ENFORCE
(
subgraph
.
at
(
pattern
.
GetPDNode
(
"transpose_out"
+
std
::
to_string
(
i
))));
PADDLE_ENFORCE
(
subgraph
.
at
(
pattern
.
GetPDNode
(
"flatten"
+
std
::
to_string
(
i
))));
PADDLE_ENFORCE
(
subgraph
.
at
(
pattern
.
GetPDNode
(
"flatten_out"
+
std
::
to_string
(
i
))));
PADDLE_ENFORCE
(
subgraph
.
at
(
input_nodes
[
i
]));
nodes
.
push_back
(
subgraph
.
at
(
input_nodes
[
i
]));
nodes
.
push_back
(
subgraph
.
at
(
pattern
.
GetPDNode
(
"transpose"
+
std
::
to_string
(
i
))));
nodes
.
push_back
(
subgraph
.
at
(
pattern
.
GetPDNode
(
"transpose_out"
+
std
::
to_string
(
i
))));
nodes
.
push_back
(
subgraph
.
at
(
pattern
.
GetPDNode
(
"flatten"
+
std
::
to_string
(
i
))));
nodes
.
push_back
(
subgraph
.
at
(
pattern
.
GetPDNode
(
"flatten_out"
+
std
::
to_string
(
i
))));
}
Node
*
concat_op
=
subgraph
.
at
(
pattern
.
GetPDNode
(
"concat"
));
Node
*
concat_out
=
subgraph
.
at
(
pattern
.
GetPDNode
(
"concat_out"
));
std
::
vector
<
std
::
string
>
input_names
;
std
::
vector
<
int
>
trans_axis
=
boost
::
get
<
std
::
vector
<
int
>>
(
nodes
[
kTransOffset
]
->
Op
()
->
GetAttr
(
"axis"
));
int
flatten_axis
=
boost
::
get
<
int
>
(
nodes
[
kFlattenOffset
]
->
Op
()
->
GetAttr
(
"axis"
));
int
concat_axis
=
boost
::
get
<
int
>
(
concat_op
->
Op
()
->
GetAttr
(
"axis"
));
std
::
string
output_name
=
concat_out
->
Name
();
for
(
int
i
=
0
;
i
<
times
;
i
++
)
{
input_names
.
push_back
(
nodes
[
i
*
kNumFields
]
->
Name
());
}
framework
::
OpDesc
new_op_desc
;
new_op_desc
.
SetType
(
"fusion_transpose_flatten_concat"
);
new_op_desc
.
SetInput
(
"X"
,
input_names
);
new_op_desc
.
SetAttr
(
"trans_axis"
,
trans_axis
);
new_op_desc
.
SetAttr
(
"flatten_axis"
,
flatten_axis
);
new_op_desc
.
SetAttr
(
"concat_axis"
,
concat_axis
);
new_op_desc
.
SetOutput
(
"Out"
,
{
output_name
});
new_op_desc
.
Flush
();
// Create a new node for the fused op.
auto
*
new_conv_op
=
graph
->
CreateOpNode
(
&
new_op_desc
);
std
::
unordered_set
<
const
Node
*>
delete_nodes
;
for
(
int
i
=
0
;
i
<
times
;
i
++
)
{
nodes
[
i
*
kNumFields
]
->
outputs
.
push_back
(
new_conv_op
);
new_conv_op
->
inputs
.
push_back
(
nodes
[
i
*
kNumFields
]);
delete_nodes
.
insert
(
nodes
[
i
*
kNumFields
+
kTransOffset
]);
delete_nodes
.
insert
(
nodes
[
i
*
kNumFields
+
kTransOutOffset
]);
delete_nodes
.
insert
(
nodes
[
i
*
kNumFields
+
kFlattenOffset
]);
delete_nodes
.
insert
(
nodes
[
i
*
kNumFields
+
kFlattenOutOffset
]);
}
delete_nodes
.
insert
(
concat_op
);
new_conv_op
->
outputs
.
push_back
(
concat_out
);
concat_out
->
inputs
.
push_back
(
new_conv_op
);
// Delete the unneeded nodes.
GraphSafeRemoveNodes
(
graph
.
get
(),
delete_nodes
);
};
gpd
(
graph
.
get
(),
handler
);
return
graph
;
}
template
class
TransposeFlattenConcatFusePass
<
1
>;
template
class
TransposeFlattenConcatFusePass
<
3
>;
template
class
TransposeFlattenConcatFusePass
<
4
>;
template
class
TransposeFlattenConcatFusePass
<
5
>;
template
class
TransposeFlattenConcatFusePass
<
6
>;
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
transpose_flatten_concat_fuse_pass
,
paddle
::
framework
::
ir
::
TransposeFlattenConcatFusePass
<
1
>
);
REGISTER_PASS
(
transpose_flatten3_concat_fuse_pass
,
paddle
::
framework
::
ir
::
TransposeFlattenConcatFusePass
<
3
>
);
REGISTER_PASS
(
transpose_flatten4_concat_fuse_pass
,
paddle
::
framework
::
ir
::
TransposeFlattenConcatFusePass
<
4
>
);
REGISTER_PASS
(
transpose_flatten5_concat_fuse_pass
,
paddle
::
framework
::
ir
::
TransposeFlattenConcatFusePass
<
5
>
);
REGISTER_PASS
(
transpose_flatten6_concat_fuse_pass
,
paddle
::
framework
::
ir
::
TransposeFlattenConcatFusePass
<
6
>
);
paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h
0 → 100644
浏览文件 @
98e85f37
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// There may be many transpose-flatten structures in a model, and the output of
// these structures will be used as inputs to the concat Op. This pattern will
// be detected by our pass. The times here represents the repeat times of this
// structure.
template
<
int
times
>
class
TransposeFlattenConcatFusePass
:
public
FusePassBase
{
public:
virtual
~
TransposeFlattenConcatFusePass
()
{}
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/inference/api/analysis_config.cc
浏览文件 @
98e85f37
...
@@ -127,6 +127,7 @@ void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size,
...
@@ -127,6 +127,7 @@ void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size,
use_tensorrt_
=
true
;
use_tensorrt_
=
true
;
tensorrt_workspace_size_
=
workspace_size
;
tensorrt_workspace_size_
=
workspace_size
;
tensorrt_max_batchsize_
=
max_batch_size
;
tensorrt_max_batchsize_
=
max_batch_size
;
Update
();
}
}
void
contrib
::
AnalysisConfig
::
Update
()
{
void
contrib
::
AnalysisConfig
::
Update
()
{
...
...
paddle/fluid/inference/api/paddle_pass_builder.h
浏览文件 @
98e85f37
...
@@ -141,6 +141,10 @@ class GpuPassStrategy : public PassStrategy {
...
@@ -141,6 +141,10 @@ class GpuPassStrategy : public PassStrategy {
"conv_elementwise_add_fuse_pass"
,
//
"conv_elementwise_add_fuse_pass"
,
//
});
});
for
(
int
i
=
6
;
i
>=
3
;
i
--
)
{
passes_
.
push_back
(
"transpose_flatten"
+
std
::
to_string
(
i
)
+
"_concat_fuse_pass"
);
}
use_gpu_
=
true
;
use_gpu_
=
true
;
}
}
...
...
paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
浏览文件 @
98e85f37
...
@@ -39,6 +39,7 @@ class ElementwiseWeightOpConverter : public OpConverter {
...
@@ -39,6 +39,7 @@ class ElementwiseWeightOpConverter : public OpConverter {
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
// Here the two nullptr looks strange, that's because the
// Here the two nullptr looks strange, that's because the
// framework::OpDesc's constructor is strange.
// framework::OpDesc's constructor is strange.
nvinfer1
::
ILayer
*
layer
=
nullptr
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
VLOG
(
3
)
<<
"Convert a fluid elementwise op to TensorRT IScaleLayer"
;
VLOG
(
3
)
<<
"Convert a fluid elementwise op to TensorRT IScaleLayer"
;
...
@@ -98,13 +99,21 @@ class ElementwiseWeightOpConverter : public OpConverter {
...
@@ -98,13 +99,21 @@ class ElementwiseWeightOpConverter : public OpConverter {
0
};
0
};
TensorRTEngine
::
Weight
power_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
nullptr
,
TensorRTEngine
::
Weight
power_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
nullptr
,
0
};
0
};
if
(
op_type_
==
"add"
)
{
nvinfer1
::
IScaleLayer
*
scale_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Scale
,
*
X
,
scale_mode
,
shift_weights
.
get
(),
scale_weights
.
get
(),
power_weights
.
get
());
layer
=
scale_layer
;
}
else
if
(
op_type_
==
"mul"
)
{
nvinfer1
::
IScaleLayer
*
scale_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Scale
,
*
X
,
scale_mode
,
scale_weights
.
get
(),
shift_weights
.
get
(),
power_weights
.
get
());
layer
=
scale_layer
;
}
nvinfer1
::
IScaleLayer
*
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Scale
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
X
),
scale_mode
,
shift_weights
.
get
(),
scale_weights
.
get
(),
power_weights
.
get
());
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
layer
->
setName
(
layer
->
setName
((
"elementwise_add
(Output: "
+
output_name
+
")"
).
c_str
());
(
"elementwise_"
+
op_type_
+
"
(Output: "
+
output_name
+
")"
).
c_str
());
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine_
->
weight_map
[
op_desc
.
Input
(
"Y"
).
front
()]
=
std
::
move
(
weight_tensor
);
engine_
->
weight_map
[
op_desc
.
Input
(
"Y"
).
front
()]
=
std
::
move
(
weight_tensor
);
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
...
@@ -113,6 +122,9 @@ class ElementwiseWeightOpConverter : public OpConverter {
...
@@ -113,6 +122,9 @@ class ElementwiseWeightOpConverter : public OpConverter {
engine_
->
DeclareOutput
(
output_name
);
engine_
->
DeclareOutput
(
output_name
);
}
}
}
}
protected:
std
::
string
op_type_
;
};
};
class
ElementwiseTensorOpConverter
:
public
OpConverter
{
class
ElementwiseTensorOpConverter
:
public
OpConverter
{
...
@@ -188,6 +200,16 @@ const std::unordered_map<std::string, nvinfer1::ElementWiseOperation>
...
@@ -188,6 +200,16 @@ const std::unordered_map<std::string, nvinfer1::ElementWiseOperation>
{
"max"
,
nvinfer1
::
ElementWiseOperation
::
kMAX
},
{
"max"
,
nvinfer1
::
ElementWiseOperation
::
kMAX
},
};
};
class
ElementwiseWeightAddOpConverter
:
public
ElementwiseWeightOpConverter
{
public:
ElementwiseWeightAddOpConverter
()
{
op_type_
=
"add"
;
}
};
class
ElementwiseWeightMulOpConverter
:
public
ElementwiseWeightOpConverter
{
public:
ElementwiseWeightMulOpConverter
()
{
op_type_
=
"mul"
;
}
};
class
ElementwiseTensorAddOpConverter
:
public
ElementwiseTensorOpConverter
{
class
ElementwiseTensorAddOpConverter
:
public
ElementwiseTensorOpConverter
{
public:
public:
ElementwiseTensorAddOpConverter
()
{
op_type_
=
"add"
;
}
ElementwiseTensorAddOpConverter
()
{
op_type_
=
"add"
;
}
...
@@ -227,7 +249,10 @@ class ElementwiseTensorPowOpConverter : public ElementwiseTensorOpConverter {
...
@@ -227,7 +249,10 @@ class ElementwiseTensorPowOpConverter : public ElementwiseTensorOpConverter {
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
elementwise_add_weight
,
ElementwiseWeightOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
elementwise_add_weight
,
ElementwiseWeightAddOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
elementwise_mul_weight
,
ElementwiseWeightMulOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
elementwise_add_tensor
,
REGISTER_TRT_OP_CONVERTER
(
elementwise_add_tensor
,
ElementwiseTensorAddOpConverter
);
ElementwiseTensorAddOpConverter
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录