Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c9995289
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c9995289
编写于
9月 14, 2018
作者:
Z
Zhaolong Xing
提交者:
GitHub
9月 14, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #13124 from NHZlX/fix_subgraph_bug
Fix tensorrt subgraph bug
上级
d4a5326a
8fb33c8a
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
227 addition
and
22 deletion
+227
-22
paddle/fluid/inference/analysis/data_flow_graph.cc
paddle/fluid/inference/analysis/data_flow_graph.cc
+5
-1
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
...fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
+15
-10
paddle/fluid/inference/analysis/subgraph_splitter.cc
paddle/fluid/inference/analysis/subgraph_splitter.cc
+189
-5
paddle/fluid/inference/analysis/subgraph_splitter_tester.cc
paddle/fluid/inference/analysis/subgraph_splitter_tester.cc
+1
-1
paddle/fluid/inference/tensorrt/convert/activation_op.cc
paddle/fluid/inference/tensorrt/convert/activation_op.cc
+2
-0
paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc
paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc
+2
-0
paddle/fluid/inference/tensorrt/convert/concat_op.cc
paddle/fluid/inference/tensorrt/convert/concat_op.cc
+2
-0
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
+2
-0
paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
+4
-0
paddle/fluid/inference/tensorrt/convert/fc_op.cc
paddle/fluid/inference/tensorrt/convert/fc_op.cc
+2
-0
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
+2
-0
paddle/fluid/operators/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt_engine_op.h
+1
-5
未找到文件。
paddle/fluid/inference/analysis/data_flow_graph.cc
浏览文件 @
c9995289
...
...
@@ -440,6 +440,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
}
return
false
;
};
for
(
auto
&
node
:
graph
)
{
for
(
auto
*
in
:
node
->
inlinks
)
{
// The Value that is written by nodes inside a sub-graph shouldn't be the
...
...
@@ -459,6 +460,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
std
::
vector
<
Node
*>
(
outputs
.
begin
(),
outputs
.
end
()));
}
// Filter the Intermediate results of the subgraph node.
void
FilterRedundantOutputOfSubGraph
(
DataFlowGraph
*
graph
)
{
std
::
vector
<
Node
*>
op_nodes
;
for
(
auto
&
node
:
GraphTraits
<
DataFlowGraph
>
(
*
graph
).
nodes_in_TS
())
{
...
...
@@ -480,9 +482,11 @@ void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
for
(
auto
*
out
:
op_nodes
[
i
]
->
outlinks
)
{
if
(
follow_up_input_names
.
count
(
out
->
name
()))
{
filtered_subgraph_outlinks
.
push_back
(
out
);
}
else
{
out
->
SetDeleted
();
}
}
PADDLE_ENFORCE_GE
(
filtered_subgraph_outlinks
.
size
(),
1UL
);
// The filtered_subgraph_outlinks may be empty.
op_nodes
[
i
]
->
outlinks
=
filtered_subgraph_outlinks
;
}
}
...
...
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
浏览文件 @
c9995289
...
...
@@ -106,20 +106,23 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
// collect inputs
std
::
unordered_set
<
std
::
string
>
input_names
;
std
::
unordered_set
<
std
::
string
>
input_names_with_id
;
for
(
auto
*
x
:
func
->
inlinks
)
{
input_names
.
insert
(
x
->
name
());
input_names_with_id
.
insert
(
x
->
name
()
+
std
::
to_string
(
x
->
id
()));
}
desc
.
SetInput
(
"Xs"
,
std
::
vector
<
std
::
string
>
(
input_names
.
begin
(),
input_names
.
end
()));
std
::
unordered_set
<
std
::
string
>
output_names
;
std
::
unordered_set
<
std
::
string
>
output_names_with_id
;
for
(
auto
*
x
:
func
->
outlinks
)
{
output_names
.
insert
(
x
->
name
());
output_names_with_id
.
insert
(
x
->
name
()
+
std
::
to_string
(
x
->
id
()));
}
std
::
vector
<
std
::
string
>
output_temp
(
output_names
.
begin
(),
output_names
.
end
());
desc
.
SetOutput
(
"Ys"
,
output_temp
);
desc
.
SetOutput
(
"Ys"
,
std
::
vector
<
std
::
string
>
(
output_names
.
begin
(),
output_names
.
end
()));
desc
.
SetType
(
"tensorrt_engine"
);
std
::
unordered_map
<
std
::
string
,
std
::
string
>
output_name_map
;
...
...
@@ -153,11 +156,12 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
std
::
vector
<
std
::
string
>
replaced_names
;
for
(
int
k
=
0
;
k
<
in_var
->
arguments_size
();
k
++
)
{
std
::
string
arg_value
=
in_var
->
arguments
(
k
);
if
(
input_names
.
count
(
arg_value
))
{
std
::
string
arg_value_with_id
=
arg_value
+
std
::
to_string
(
var2id
[
arg_value
]);
if
(
input_names_with_id
.
count
(
arg_value_with_id
))
{
replaced_names
.
push_back
(
arg_value
);
}
else
{
replaced_names
.
push_back
(
arg_value
+
std
::
to_string
(
var2id
[
arg_value
]));
replaced_names
.
push_back
(
arg_value_with_id
);
}
}
in_var
->
clear_arguments
();
...
...
@@ -176,11 +180,12 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
std
::
vector
<
std
::
string
>
replaced_names
;
for
(
int
k
=
0
;
k
<
out_var
->
arguments_size
();
k
++
)
{
std
::
string
arg_value
=
out_var
->
arguments
(
k
);
if
(
output_names
.
count
(
arg_value
))
{
output_name_map
[
arg_value
]
=
std
::
string
arg_value_with_id
=
arg_value
+
std
::
to_string
(
var2id
[
arg_value
]);
if
(
output_names_with_id
.
count
(
arg_value_with_id
))
{
output_name_map
[
arg_value
]
=
arg_value_with_id
;
}
replaced_names
.
push_back
(
arg_value
+
std
::
to_string
(
var2id
[
arg_value
])
);
replaced_names
.
push_back
(
arg_value
_with_id
);
}
out_var
->
clear_arguments
();
for
(
size_t
k
=
0
;
k
<
replaced_names
.
size
();
k
++
)
{
...
...
paddle/fluid/inference/analysis/subgraph_splitter.cc
浏览文件 @
c9995289
...
...
@@ -74,13 +74,134 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
node_map
.
at
(
b
)
->
attr
(
kUnionFindParent
).
Int32
()
=
a_ancestor
;
}
// This is a simple representation of a graph.
// The BriefNode hold the pointer of the Node.
// This is to avoid changing the original graph
// in the process of trt graph analysis.
struct
BriefNode
{
explicit
BriefNode
(
Node
*
n
)
{
node
=
n
;
}
Node
*
node
;
std
::
vector
<
BriefNode
*>
inlinks
;
std
::
vector
<
BriefNode
*>
outlinks
;
};
// Union two adjacent BriefNode.
// Suppose we have two adjacent nodes src and dst.
// We will perform the following operations:
// 1. add all inputs(except src) of dst to src inlinks.
// 2. add all outputs of dst to src outlinks.
// 3. change all the dst's inputs and outputs
// corresponding inlinks and outlinks to src node.
// 4. delete all dst's inlinks and outlinks.
void
UnionContractedNodes
(
const
std
::
unordered_map
<
int
,
BriefNode
*>
&
node_map
,
int
src_id
,
int
dst_id
)
{
// merge the two adjacent nodes into one node.
BriefNode
*
src_node
=
node_map
.
at
(
src_id
);
BriefNode
*
dst_node
=
node_map
.
at
(
dst_id
);
std
::
unordered_set
<
BriefNode
*>
inputs
(
src_node
->
inlinks
.
begin
(),
src_node
->
inlinks
.
end
());
std
::
unordered_set
<
BriefNode
*>
outputs
;
for
(
auto
*
n
:
src_node
->
outlinks
)
{
if
(
n
!=
dst_node
)
outputs
.
insert
(
n
);
}
// Add the inlinks and outlinks of dst node to src node.
std
::
vector
<
BriefNode
*>
dst_in_nodes
=
dst_node
->
inlinks
;
for
(
BriefNode
*
node
:
dst_in_nodes
)
{
if
(
node
!=
src_node
)
{
inputs
.
insert
(
node
);
}
}
std
::
vector
<
BriefNode
*>
dst_out_nodes
=
dst_node
->
outlinks
;
for
(
BriefNode
*
node
:
dst_out_nodes
)
{
outputs
.
insert
(
node
);
}
// update the dst and src node's inlinks and outlinks.
src_node
->
inlinks
=
std
::
move
(
std
::
vector
<
BriefNode
*>
(
inputs
.
begin
(),
inputs
.
end
()));
src_node
->
outlinks
=
std
::
move
(
std
::
vector
<
BriefNode
*>
(
outputs
.
begin
(),
outputs
.
end
()));
dst_node
->
inlinks
.
clear
();
dst_node
->
outlinks
.
clear
();
auto
inlink_or_outlink_cleaner
=
[
&
](
std
::
vector
<
BriefNode
*>
&
nodes
)
{
for
(
auto
*&
n
:
nodes
)
{
if
(
n
==
src_node
||
n
==
dst_node
)
{
n
=
src_node
;
}
}
};
// Change all the dst inputs and outputs corresponding inlink and
// outlink to the src node.
for
(
auto
*
node
:
src_node
->
inlinks
)
{
inlink_or_outlink_cleaner
(
node
->
outlinks
);
}
for
(
auto
*
node
:
src_node
->
outlinks
)
{
inlink_or_outlink_cleaner
(
node
->
inlinks
);
}
}
// FlexibleDFS
// If reverse is true, do reverse dfs.
// If enter func is not nullptr, calls enter(node) before visiting any children
// of node.
// If leave func not nullptr, calls leave(node) after visiting all parents of
// node.
void
FlexibleDFS
(
const
std
::
vector
<
BriefNode
*>
&
source
,
bool
reverse
,
const
std
::
function
<
bool
(
const
BriefNode
*
)
>
&
enter
,
const
std
::
function
<
bool
(
const
BriefNode
*
)
>
&
leave
)
{
typedef
struct
{
const
BriefNode
*
node
;
bool
leave
;
}
FNode
;
std
::
vector
<
FNode
>
stack
;
for
(
auto
&
node
:
source
)
{
stack
.
push_back
(
FNode
{
node
,
false
});
}
std
::
unordered_set
<
const
BriefNode
*>
visited
;
while
(
!
stack
.
empty
())
{
auto
fnode
=
stack
.
back
();
stack
.
pop_back
();
if
(
fnode
.
leave
)
{
if
(
leave
&&
!
leave
(
fnode
.
node
))
return
;
}
if
(
visited
.
count
(
fnode
.
node
))
continue
;
visited
.
insert
(
fnode
.
node
);
if
(
enter
&&
!
enter
(
fnode
.
node
))
return
;
if
(
leave
)
stack
.
push_back
(
FNode
{
fnode
.
node
,
true
});
const
std
::
vector
<
BriefNode
*>
iter_nodes
=
reverse
==
true
?
fnode
.
node
->
inlinks
:
fnode
.
node
->
outlinks
;
for
(
const
BriefNode
*
node
:
iter_nodes
)
{
if
(
!
visited
.
count
(
node
))
{
stack
.
push_back
(
FNode
{
node
,
false
});
}
}
}
}
std
::
vector
<
std
::
vector
<
Node
*>>
SubGraphSplitter
::
ExtractSubGraphs
()
{
// Run the Extract algorithm to find all subgraphs.
std
::
vector
<
Node
*>
marked_nodes
;
// We use brief_node_map to represent the original graph in order to avoid
// changing the original graph.
std
::
unordered_map
<
int
,
BriefNode
*>
brief_node_map
;
for
(
auto
&
node
:
GraphTraits
<
DataFlowGraph
>
(
*
graph_
).
nodes_in_TS
())
{
brief_node_map
[
node
.
id
()]
=
new
BriefNode
(
&
node
);
if
(
node
.
attr
(
kMarkerAttrName
).
Bool
())
{
marked_nodes
.
push_back
(
&
node
);
}
}
// extract sub-graphs in the marked node set, use Union Find algorithm.
node_map_t
node_map
;
// id to ptr
for
(
auto
*
n
:
marked_nodes
)
{
...
...
@@ -88,11 +209,73 @@ std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
n
->
attr
(
kUnionFindParent
).
Int32
()
=
n
->
id
();
node_map
[
n
->
id
()]
=
n
;
}
std
::
unordered_set
<
Node
*>
visited
;
for
(
auto
*
n
:
marked_nodes
)
{
for
(
auto
*
out
:
n
->
outlinks
)
{
if
(
node_map
.
count
(
out
->
id
()))
{
UnionFindCombine
(
node_map
,
n
->
id
(),
out
->
id
());
// create breif node map
for
(
auto
&
itr
:
brief_node_map
)
{
for
(
Node
*
node
:
itr
.
second
->
node
->
inlinks
)
{
itr
.
second
->
inlinks
.
push_back
(
brief_node_map
[
node
->
id
()]);
}
for
(
Node
*
node
:
itr
.
second
->
node
->
outlinks
)
{
itr
.
second
->
outlinks
.
push_back
(
brief_node_map
[
node
->
id
()]);
}
}
for
(
auto
&
itr
:
brief_node_map
)
{
BriefNode
*
brief_node
=
itr
.
second
;
if
(
!
brief_node
->
node
->
attr
(
kMarkerAttrName
).
Bool
())
{
VLOG
(
4
)
<<
brief_node
->
node
->
id
()
<<
" node not a trt candicate."
;
continue
;
}
// Our algorithm must guarantee that:
// 1. The graph is always directed acyclic graph(DAG).
// 2. If there is a path in the subgraph from X to Y (X and Y are both
// nodes in the subgraph), then all paths from X to Y are in the
// subgraph.
//
// In order to achieve the above guarantee.
// For adjacent nodes src -> dst.
// 1. Get all dst input nodes except src.
// 2. Reverse DFS from those input nodes
// 3. If there is a path from input nodes to src,
// then the src and dst nodes can not be fused into one node,
// otherwise it can be done.
while
(
true
)
{
std
::
unordered_set
<
BriefNode
*>
contract_nodes
;
for
(
auto
*
out
:
brief_node
->
outlinks
)
{
// must be an trt candidate
if
(
!
out
->
node
->
attr
(
kMarkerAttrName
).
Bool
())
continue
;
// get all dst input nodes except src.
std
::
vector
<
BriefNode
*>
source_nodes
;
for
(
auto
*
n
:
out
->
inlinks
)
{
if
(
n
!=
brief_node
)
{
source_nodes
.
push_back
(
n
);
}
}
// Reverse DFS from the source_nodes.
bool
have_excess_path
=
false
;
FlexibleDFS
(
source_nodes
,
true
,
nullptr
,
[
&
have_excess_path
,
brief_node
](
const
BriefNode
*
n
)
{
if
(
n
==
brief_node
)
{
have_excess_path
=
true
;
return
false
;
}
return
true
;
});
if
(
have_excess_path
)
continue
;
contract_nodes
.
insert
(
out
);
}
if
(
contract_nodes
.
empty
())
break
;
for
(
auto
dst_node
:
contract_nodes
)
{
UnionFindCombine
(
node_map
,
brief_node
->
node
->
id
(),
dst_node
->
node
->
id
());
UnionContractedNodes
(
brief_node_map
,
brief_node
->
node
->
id
(),
dst_node
->
node
->
id
());
}
}
}
...
...
@@ -128,6 +311,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() {
auto
io
=
ExtractInputAndOutputOfSubGraph
(
subgraph
);
block_node
->
inlinks
=
std
::
move
(
io
.
first
);
block_node
->
outlinks
=
std
::
move
(
io
.
second
);
for
(
auto
*
node
:
subgraph
)
{
// TODO(Superjomn) need a unified mechanism to treat deleted node in each
// pass.
...
...
paddle/fluid/inference/analysis/subgraph_splitter_tester.cc
浏览文件 @
c9995289
...
...
@@ -82,7 +82,7 @@ TEST(SubGraphSplitter, Fuse) {
// At least one nodes should be deleted.
ASSERT_EQ
(
dfg
.
nodes
.
size
(),
count0
+
1
);
// added a new FunctionBlock
ASSERT_EQ
(
6
,
count1
);
ASSERT_EQ
(
11
,
count1
);
}
}
// namespace analysis
...
...
paddle/fluid/inference/tensorrt/convert/activation_op.cc
浏览文件 @
c9995289
...
...
@@ -35,6 +35,8 @@ class ReluOpConverter : public OpConverter {
engine_
,
Activation
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input_tensor
),
nvinfer1
::
ActivationType
::
kRELU
);
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
layer
->
setName
((
"relu (Output: "
+
output_name
+
")"
).
c_str
());
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
)
{
// the test framework can not determine which is the
// output, so place the declaration inside.
...
...
paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc
浏览文件 @
c9995289
...
...
@@ -116,6 +116,8 @@ class BatchNormOpConverter : public OpConverter {
scale_weights
.
get
(),
power_weights
.
get
());
auto
output_name
=
op_desc
.
Output
(
"Y"
).
front
();
layer
->
setName
((
"batch_norm (Output: "
+
output_name
+
")"
).
c_str
());
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine_
->
weight_map
[
op_desc
.
Input
(
"Bias"
).
front
()]
=
std
::
move
(
combile_bias_tensor
);
engine_
->
weight_map
[
op_desc
.
Input
(
"Scale"
).
front
()]
=
...
...
paddle/fluid/inference/tensorrt/convert/concat_op.cc
浏览文件 @
c9995289
...
...
@@ -42,6 +42,8 @@ class ConcatOpConverter : public OpConverter {
axis
=
axis
-
1
;
// Remove batch dim
layer
->
setAxis
(
axis
);
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
layer
->
setName
((
"concat (Output: "
+
output_name
+
")"
).
c_str
());
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
)
{
// the test framework can not determine which is the
// output, so place the declaration inside.
...
...
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
浏览文件 @
c9995289
...
...
@@ -78,8 +78,10 @@ class Conv2dOpConverter : public OpConverter {
layer
->
setNbGroups
(
groups
);
auto
output_name
=
op_desc
.
Output
(
"Output"
).
front
();
layer
->
setName
((
"conv2d (Output: "
+
output_name
+
")"
).
c_str
());
engine_
->
weight_map
[
op_desc
.
Input
(
"Filter"
).
front
()]
=
std
::
move
(
weight_tensor
);
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
)
{
engine_
->
DeclareOutput
(
output_name
);
...
...
paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
浏览文件 @
c9995289
...
...
@@ -89,6 +89,8 @@ class ElementwiseWeightOpConverter : public OpConverter {
shift_weights
.
get
(),
scale_weights
.
get
(),
power_weights
.
get
());
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
layer
->
setName
((
"elementwise_add (Output: "
+
output_name
+
")"
).
c_str
());
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine_
->
weight_map
[
op_desc
.
Input
(
"Y"
).
front
()]
=
std
::
move
(
weight_tensor
);
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
)
{
// the test framework can not determine which is the
...
...
@@ -137,6 +139,8 @@ class ElementwiseTensorOpConverter : public OpConverter {
*
const_cast
<
nvinfer1
::
ITensor
*>
(
Y
),
op_pair
->
second
);
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
layer
->
setName
((
"elementwise (Output: "
+
output_name
+
")"
).
c_str
());
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
)
{
// the test framework can not determine which is the
// output, so place the declaration inside.
...
...
paddle/fluid/inference/tensorrt/convert/fc_op.cc
浏览文件 @
c9995289
...
...
@@ -107,6 +107,8 @@ class FcOpConverter : public OpConverter {
n_output
,
tmp_weight
.
get
(),
bias
.
get
());
auto
output_name
=
op_desc
.
Output
(
"Out"
).
front
();
layer
->
setName
((
"fc (Output: "
+
output_name
+
")"
).
c_str
());
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
engine_
->
weight_map
[
op_desc
.
Input
(
"Y"
).
front
()]
=
std
::
move
(
tmp
);
if
(
test_mode
)
{
...
...
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
浏览文件 @
c9995289
...
...
@@ -72,6 +72,8 @@ class Pool2dOpConverter : public OpConverter {
layer
->
setPadding
(
nv_paddings
);
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
layer
->
setName
((
"pool2d (Output: "
+
output_name
+
")"
).
c_str
());
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
)
{
engine_
->
DeclareOutput
(
output_name
);
...
...
paddle/fluid/operators/tensorrt_engine_op.h
浏览文件 @
c9995289
...
...
@@ -160,11 +160,7 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
fluid_t
->
mutable_data
<
float
>
(
platform
::
CUDAPlace
(
boost
::
get
<
platform
::
CUDAPlace
>
(
context
.
GetPlace
()).
device
)),
size
*
sizeof
(
float
));
//} else {
// engine->GetOutputInGPU(
// y, fluid_t->mutable_data<float>(platform::CUDAPlace()),
// size * sizeof(float));
//}
output_index
+=
1
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录