Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
03ff4f68
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
03ff4f68
编写于
9月 11, 2018
作者:
N
nhzlx
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix subgraph bug!
上级
5ec2fb0c
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
215 addition
and
60 deletion
+215
-60
paddle/fluid/inference/analysis/data_flow_graph.cc
paddle/fluid/inference/analysis/data_flow_graph.cc
+3
-36
paddle/fluid/inference/analysis/data_flow_graph.h
paddle/fluid/inference/analysis/data_flow_graph.h
+0
-3
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
...fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
+15
-10
paddle/fluid/inference/analysis/subgraph_splitter.cc
paddle/fluid/inference/analysis/subgraph_splitter.cc
+181
-5
paddle/fluid/inference/analysis/subgraph_splitter_tester.cc
paddle/fluid/inference/analysis/subgraph_splitter_tester.cc
+1
-1
paddle/fluid/operators/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt_engine_op.h
+15
-5
未找到文件。
paddle/fluid/inference/analysis/data_flow_graph.cc
浏览文件 @
03ff4f68
...
@@ -440,6 +440,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
...
@@ -440,6 +440,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
}
}
return
false
;
return
false
;
};
};
for
(
auto
&
node
:
graph
)
{
for
(
auto
&
node
:
graph
)
{
for
(
auto
*
in
:
node
->
inlinks
)
{
for
(
auto
*
in
:
node
->
inlinks
)
{
// The Value that is written by nodes inside a sub-graph shouldn't be the
// The Value that is written by nodes inside a sub-graph shouldn't be the
...
@@ -459,6 +460,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
...
@@ -459,6 +460,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
std
::
vector
<
Node
*>
(
outputs
.
begin
(),
outputs
.
end
()));
std
::
vector
<
Node
*>
(
outputs
.
begin
(),
outputs
.
end
()));
}
}
// Filter the Intermediate results of the subgraph node.
void
FilterRedundantOutputOfSubGraph
(
DataFlowGraph
*
graph
)
{
void
FilterRedundantOutputOfSubGraph
(
DataFlowGraph
*
graph
)
{
std
::
vector
<
Node
*>
op_nodes
;
std
::
vector
<
Node
*>
op_nodes
;
for
(
auto
&
node
:
GraphTraits
<
DataFlowGraph
>
(
*
graph
).
nodes_in_TS
())
{
for
(
auto
&
node
:
GraphTraits
<
DataFlowGraph
>
(
*
graph
).
nodes_in_TS
())
{
...
@@ -484,46 +486,11 @@ void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
...
@@ -484,46 +486,11 @@ void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
out
->
SetDeleted
();
out
->
SetDeleted
();
}
}
}
}
PADDLE_ENFORCE_GE
(
filtered_subgraph_outlinks
.
size
(),
1UL
);
// The filtered_subgraph_outlinks may be empty.
op_nodes
[
i
]
->
outlinks
=
filtered_subgraph_outlinks
;
op_nodes
[
i
]
->
outlinks
=
filtered_subgraph_outlinks
;
}
}
}
}
void
FlexibleDFS
(
const
std
::
vector
<
Node
*>
&
source
,
bool
reverse
,
const
std
::
function
<
bool
(
const
Node
*
)
>
&
enter
,
const
std
::
function
<
bool
(
const
Node
*
)
>
&
leave
)
{
typedef
struct
{
const
Node
*
node
;
bool
leave
;
}
FNode
;
std
::
vector
<
FNode
>
stack
;
for
(
auto
&
node
:
source
)
{
stack
.
push_back
(
FNode
{
node
,
false
});
}
std
::
unordered_set
<
const
Node
*>
visited
;
while
(
!
stack
.
empty
())
{
auto
fnode
=
stack
.
back
();
stack
.
pop_back
();
if
(
fnode
.
leave
)
{
if
(
leave
&&
!
leave
(
fnode
.
node
))
return
;
}
if
(
visited
.
count
(
fnode
.
node
))
continue
;
visited
.
insert
(
fnode
.
node
);
if
(
enter
&&
!
enter
(
fnode
.
node
))
return
;
if
(
leave
)
stack
.
push_back
(
FNode
{
fnode
.
node
,
true
});
const
std
::
vector
<
Node
*>
iter_nodes
=
reverse
==
true
?
fnode
.
node
->
inlinks
:
fnode
.
node
->
outlinks
;
for
(
const
Node
*
node
:
iter_nodes
)
{
if
(
!
visited
.
count
(
node
))
{
stack
.
push_back
(
FNode
{
node
,
false
});
}
}
}
}
}
// namespace analysis
}
// namespace analysis
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
paddle/fluid/inference/analysis/data_flow_graph.h
浏览文件 @
03ff4f68
...
@@ -204,9 +204,6 @@ std::pair<std::vector<Node *>, std::vector<Node *>>
...
@@ -204,9 +204,6 @@ std::pair<std::vector<Node *>, std::vector<Node *>>
ExtractInputAndOutputOfSubGraph
(
std
::
vector
<
Node
*>
&
graph
);
// NOLINT
ExtractInputAndOutputOfSubGraph
(
std
::
vector
<
Node
*>
&
graph
);
// NOLINT
void
FilterRedundantOutputOfSubGraph
(
DataFlowGraph
*
graph
);
void
FilterRedundantOutputOfSubGraph
(
DataFlowGraph
*
graph
);
void
FlexibleDFS
(
const
std
::
vector
<
Node
*>
&
source
,
bool
reverse
,
const
std
::
function
<
bool
(
const
Node
*
)
>
&
enter
,
const
std
::
function
<
bool
(
const
Node
*
)
>
&
leave
);
}
// namespace analysis
}
// namespace analysis
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
浏览文件 @
03ff4f68
...
@@ -106,20 +106,23 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
...
@@ -106,20 +106,23 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
// collect inputs
// collect inputs
std
::
unordered_set
<
std
::
string
>
input_names
;
std
::
unordered_set
<
std
::
string
>
input_names
;
std
::
unordered_set
<
std
::
string
>
input_names_with_id
;
for
(
auto
*
x
:
func
->
inlinks
)
{
for
(
auto
*
x
:
func
->
inlinks
)
{
input_names
.
insert
(
x
->
name
());
input_names
.
insert
(
x
->
name
());
input_names_with_id
.
insert
(
x
->
name
()
+
std
::
to_string
(
x
->
id
()));
}
}
desc
.
SetInput
(
desc
.
SetInput
(
"Xs"
,
std
::
vector
<
std
::
string
>
(
input_names
.
begin
(),
input_names
.
end
()));
"Xs"
,
std
::
vector
<
std
::
string
>
(
input_names
.
begin
(),
input_names
.
end
()));
std
::
unordered_set
<
std
::
string
>
output_names
;
std
::
unordered_set
<
std
::
string
>
output_names
;
std
::
unordered_set
<
std
::
string
>
output_names_with_id
;
for
(
auto
*
x
:
func
->
outlinks
)
{
for
(
auto
*
x
:
func
->
outlinks
)
{
output_names
.
insert
(
x
->
name
());
output_names
.
insert
(
x
->
name
());
output_names_with_id
.
insert
(
x
->
name
()
+
std
::
to_string
(
x
->
id
()));
}
}
std
::
vector
<
std
::
string
>
output_temp
(
output_names
.
begin
(),
desc
.
SetOutput
(
output_names
.
end
());
"Ys"
,
std
::
vector
<
std
::
string
>
(
output_names
.
begin
(),
output_names
.
end
()));
desc
.
SetOutput
(
"Ys"
,
output_temp
);
desc
.
SetType
(
"tensorrt_engine"
);
desc
.
SetType
(
"tensorrt_engine"
);
std
::
unordered_map
<
std
::
string
,
std
::
string
>
output_name_map
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
output_name_map
;
...
@@ -153,11 +156,12 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
...
@@ -153,11 +156,12 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
std
::
vector
<
std
::
string
>
replaced_names
;
std
::
vector
<
std
::
string
>
replaced_names
;
for
(
int
k
=
0
;
k
<
in_var
->
arguments_size
();
k
++
)
{
for
(
int
k
=
0
;
k
<
in_var
->
arguments_size
();
k
++
)
{
std
::
string
arg_value
=
in_var
->
arguments
(
k
);
std
::
string
arg_value
=
in_var
->
arguments
(
k
);
if
(
input_names
.
count
(
arg_value
))
{
std
::
string
arg_value_with_id
=
arg_value
+
std
::
to_string
(
var2id
[
arg_value
]);
if
(
input_names_with_id
.
count
(
arg_value_with_id
))
{
replaced_names
.
push_back
(
arg_value
);
replaced_names
.
push_back
(
arg_value
);
}
else
{
}
else
{
replaced_names
.
push_back
(
arg_value
+
replaced_names
.
push_back
(
arg_value_with_id
);
std
::
to_string
(
var2id
[
arg_value
]));
}
}
}
}
in_var
->
clear_arguments
();
in_var
->
clear_arguments
();
...
@@ -176,11 +180,12 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
...
@@ -176,11 +180,12 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
std
::
vector
<
std
::
string
>
replaced_names
;
std
::
vector
<
std
::
string
>
replaced_names
;
for
(
int
k
=
0
;
k
<
out_var
->
arguments_size
();
k
++
)
{
for
(
int
k
=
0
;
k
<
out_var
->
arguments_size
();
k
++
)
{
std
::
string
arg_value
=
out_var
->
arguments
(
k
);
std
::
string
arg_value
=
out_var
->
arguments
(
k
);
if
(
output_names
.
count
(
arg_value
))
{
std
::
string
arg_value_with_id
=
output_name_map
[
arg_value
]
=
arg_value
+
std
::
to_string
(
var2id
[
arg_value
]);
arg_value
+
std
::
to_string
(
var2id
[
arg_value
]);
if
(
output_names_with_id
.
count
(
arg_value_with_id
))
{
output_name_map
[
arg_value
]
=
arg_value_with_id
;
}
}
replaced_names
.
push_back
(
arg_value
+
std
::
to_string
(
var2id
[
arg_value
])
);
replaced_names
.
push_back
(
arg_value
_with_id
);
}
}
out_var
->
clear_arguments
();
out_var
->
clear_arguments
();
for
(
size_t
k
=
0
;
k
<
replaced_names
.
size
();
k
++
)
{
for
(
size_t
k
=
0
;
k
<
replaced_names
.
size
();
k
++
)
{
...
...
paddle/fluid/inference/analysis/subgraph_splitter.cc
浏览文件 @
03ff4f68
...
@@ -74,13 +74,126 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
...
@@ -74,13 +74,126 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
node_map
.
at
(
b
)
->
attr
(
kUnionFindParent
).
Int32
()
=
a_ancestor
;
node_map
.
at
(
b
)
->
attr
(
kUnionFindParent
).
Int32
()
=
a_ancestor
;
}
}
// This is a simple representation of a graph.
// The BriefNode hold the pointer of the Node.
// This is to avoid changing the original graph
// in the process of trt graph analysis.
struct
BriefNode
{
explicit
BriefNode
(
Node
*
n
)
{
node
=
n
;
}
Node
*
node
;
std
::
vector
<
BriefNode
*>
inlinks
;
std
::
vector
<
BriefNode
*>
outlinks
;
};
void
UnionContractedNodes
(
const
std
::
unordered_map
<
int
,
BriefNode
*>
&
node_map
,
int
src_id
,
int
dst_id
)
{
// merge the two adjacent nodes into one node.
BriefNode
*
src_node
=
node_map
.
at
(
src_id
);
BriefNode
*
dst_node
=
node_map
.
at
(
dst_id
);
std
::
unordered_set
<
BriefNode
*>
inputs
(
src_node
->
inlinks
.
begin
(),
src_node
->
inlinks
.
end
());
std
::
unordered_set
<
BriefNode
*>
outputs
;
for
(
auto
*
n
:
src_node
->
outlinks
)
{
if
(
n
!=
dst_node
)
outputs
.
insert
(
n
);
}
// Add the inlinks and outlinks of dst node to src node.
std
::
vector
<
BriefNode
*>
dst_in_nodes
=
dst_node
->
inlinks
;
for
(
BriefNode
*
node
:
dst_in_nodes
)
{
if
(
node
!=
src_node
)
{
inputs
.
insert
(
node
);
}
}
std
::
vector
<
BriefNode
*>
dst_out_nodes
=
dst_node
->
outlinks
;
for
(
BriefNode
*
node
:
dst_out_nodes
)
{
outputs
.
insert
(
node
);
}
// update the dst and src node's inlinks and outlinks.
src_node
->
inlinks
=
std
::
move
(
std
::
vector
<
BriefNode
*>
(
inputs
.
begin
(),
inputs
.
end
()));
src_node
->
outlinks
=
std
::
move
(
std
::
vector
<
BriefNode
*>
(
outputs
.
begin
(),
outputs
.
end
()));
dst_node
->
inlinks
.
clear
();
dst_node
->
outlinks
.
clear
();
auto
inlink_or_outlink_cleaner
=
[
&
](
std
::
vector
<
BriefNode
*>
&
nodes
)
{
for
(
auto
*&
n
:
nodes
)
{
if
(
n
==
src_node
||
n
==
dst_node
)
{
n
=
src_node
;
}
}
};
// Change all the dst inputs and outputs corresponding inlink and
// outlink to the src node.
for
(
auto
*
node
:
src_node
->
inlinks
)
{
inlink_or_outlink_cleaner
(
node
->
outlinks
);
}
for
(
auto
*
node
:
src_node
->
outlinks
)
{
inlink_or_outlink_cleaner
(
node
->
inlinks
);
}
}
// FlexibleDfS
// If reverse is true, do reverse dfs.
// If enter func is not nullptr, calls enter(node) before visiting any children
// of node.
// If leave func not nullptr, calls leave(node) after visiting all parents of
// node.
void
FlexibleDFS
(
const
std
::
vector
<
BriefNode
*>
&
source
,
bool
reverse
,
const
std
::
function
<
bool
(
const
BriefNode
*
)
>
&
enter
,
const
std
::
function
<
bool
(
const
BriefNode
*
)
>
&
leave
)
{
typedef
struct
{
const
BriefNode
*
node
;
bool
leave
;
}
FNode
;
std
::
vector
<
FNode
>
stack
;
for
(
auto
&
node
:
source
)
{
stack
.
push_back
(
FNode
{
node
,
false
});
}
std
::
unordered_set
<
const
BriefNode
*>
visited
;
while
(
!
stack
.
empty
())
{
auto
fnode
=
stack
.
back
();
stack
.
pop_back
();
if
(
fnode
.
leave
)
{
if
(
leave
&&
!
leave
(
fnode
.
node
))
return
;
}
if
(
visited
.
count
(
fnode
.
node
))
continue
;
visited
.
insert
(
fnode
.
node
);
if
(
enter
&&
!
enter
(
fnode
.
node
))
return
;
if
(
leave
)
stack
.
push_back
(
FNode
{
fnode
.
node
,
true
});
const
std
::
vector
<
BriefNode
*>
iter_nodes
=
reverse
==
true
?
fnode
.
node
->
inlinks
:
fnode
.
node
->
outlinks
;
for
(
const
BriefNode
*
node
:
iter_nodes
)
{
if
(
!
visited
.
count
(
node
))
{
stack
.
push_back
(
FNode
{
node
,
false
});
}
}
}
}
std
::
vector
<
std
::
vector
<
Node
*>>
SubGraphSplitter
::
ExtractSubGraphs
()
{
std
::
vector
<
std
::
vector
<
Node
*>>
SubGraphSplitter
::
ExtractSubGraphs
()
{
// Run the Extract algorithm to find all subgraphs.
std
::
vector
<
Node
*>
marked_nodes
;
std
::
vector
<
Node
*>
marked_nodes
;
// We use brief_node_map to represent the original graph in order to avoid
// changing the original graph.
std
::
unordered_map
<
int
,
BriefNode
*>
brief_node_map
;
for
(
auto
&
node
:
GraphTraits
<
DataFlowGraph
>
(
*
graph_
).
nodes_in_TS
())
{
for
(
auto
&
node
:
GraphTraits
<
DataFlowGraph
>
(
*
graph_
).
nodes_in_TS
())
{
brief_node_map
[
node
.
id
()]
=
new
BriefNode
(
&
node
);
if
(
node
.
attr
(
kMarkerAttrName
).
Bool
())
{
if
(
node
.
attr
(
kMarkerAttrName
).
Bool
())
{
marked_nodes
.
push_back
(
&
node
);
marked_nodes
.
push_back
(
&
node
);
}
}
}
}
// extract sub-graphs in the marked node set, use Union Find algorithm.
// extract sub-graphs in the marked node set, use Union Find algorithm.
node_map_t
node_map
;
// id to ptr
node_map_t
node_map
;
// id to ptr
for
(
auto
*
n
:
marked_nodes
)
{
for
(
auto
*
n
:
marked_nodes
)
{
...
@@ -88,11 +201,73 @@ std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
...
@@ -88,11 +201,73 @@ std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
n
->
attr
(
kUnionFindParent
).
Int32
()
=
n
->
id
();
n
->
attr
(
kUnionFindParent
).
Int32
()
=
n
->
id
();
node_map
[
n
->
id
()]
=
n
;
node_map
[
n
->
id
()]
=
n
;
}
}
std
::
unordered_set
<
Node
*>
visited
;
for
(
auto
*
n
:
marked_nodes
)
{
// create breif node map
for
(
auto
*
out
:
n
->
outlinks
)
{
for
(
auto
&
itr
:
brief_node_map
)
{
if
(
node_map
.
count
(
out
->
id
()))
{
for
(
Node
*
node
:
itr
.
second
->
node
->
inlinks
)
{
UnionFindCombine
(
node_map
,
n
->
id
(),
out
->
id
());
itr
.
second
->
inlinks
.
push_back
(
brief_node_map
[
node
->
id
()]);
}
for
(
Node
*
node
:
itr
.
second
->
node
->
outlinks
)
{
itr
.
second
->
outlinks
.
push_back
(
brief_node_map
[
node
->
id
()]);
}
}
for
(
auto
&
itr
:
brief_node_map
)
{
BriefNode
*
brief_node
=
itr
.
second
;
if
(
!
brief_node
->
node
->
attr
(
kMarkerAttrName
).
Bool
())
{
VLOG
(
4
)
<<
brief_node
->
node
->
id
()
<<
" node not a trt candicate."
;
continue
;
}
// Our algorithm must guarantee that:
// 1. The graph is always directed acyclic graph(DAG).
// 2. If there is a path in the subgraph from X to Y (X and Y are both
// nodes
// in the subgraph), then all paths from X to Y are in the subgraph.
//
// In order to achieve the above guarantee.
// For adjacent nodes src -> dst.
// 1. Get all dst input nodes except src.
// 2. Reverse DFS from those input nodes
// 3. If there is a path from input nodes to src,
// then the src and dst nodes can not be fused into one node,
// otherwise it can be done.
while
(
true
)
{
std
::
unordered_set
<
BriefNode
*>
contract_nodes
;
for
(
auto
*
out
:
brief_node
->
outlinks
)
{
// must be an trt candidate
if
(
!
out
->
node
->
attr
(
kMarkerAttrName
).
Bool
())
continue
;
// get all dst input nodes except src.
std
::
vector
<
BriefNode
*>
source_nodes
;
for
(
auto
*
n
:
out
->
inlinks
)
{
if
(
n
!=
brief_node
)
{
source_nodes
.
push_back
(
n
);
}
}
// Reverse DFS from the source_nodes.
bool
have_excess_path
=
false
;
FlexibleDFS
(
source_nodes
,
true
,
nullptr
,
[
&
have_excess_path
,
brief_node
](
const
BriefNode
*
n
)
{
if
(
n
==
brief_node
)
{
have_excess_path
=
true
;
return
false
;
}
return
true
;
});
if
(
have_excess_path
)
continue
;
contract_nodes
.
insert
(
out
);
}
if
(
contract_nodes
.
empty
())
break
;
for
(
auto
dst_node
:
contract_nodes
)
{
UnionFindCombine
(
node_map
,
brief_node
->
node
->
id
(),
dst_node
->
node
->
id
());
UnionContractedNodes
(
brief_node_map
,
brief_node
->
node
->
id
(),
dst_node
->
node
->
id
());
}
}
}
}
}
}
...
@@ -128,6 +303,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() {
...
@@ -128,6 +303,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() {
auto
io
=
ExtractInputAndOutputOfSubGraph
(
subgraph
);
auto
io
=
ExtractInputAndOutputOfSubGraph
(
subgraph
);
block_node
->
inlinks
=
std
::
move
(
io
.
first
);
block_node
->
inlinks
=
std
::
move
(
io
.
first
);
block_node
->
outlinks
=
std
::
move
(
io
.
second
);
block_node
->
outlinks
=
std
::
move
(
io
.
second
);
for
(
auto
*
node
:
subgraph
)
{
for
(
auto
*
node
:
subgraph
)
{
// TODO(Superjomn) need a unified mechanism to treat deleted node in each
// TODO(Superjomn) need a unified mechanism to treat deleted node in each
// pass.
// pass.
...
...
paddle/fluid/inference/analysis/subgraph_splitter_tester.cc
浏览文件 @
03ff4f68
...
@@ -82,7 +82,7 @@ TEST(SubGraphSplitter, Fuse) {
...
@@ -82,7 +82,7 @@ TEST(SubGraphSplitter, Fuse) {
// At least one nodes should be deleted.
// At least one nodes should be deleted.
ASSERT_EQ
(
dfg
.
nodes
.
size
(),
count0
+
1
);
// added a new FunctionBlock
ASSERT_EQ
(
dfg
.
nodes
.
size
(),
count0
+
1
);
// added a new FunctionBlock
ASSERT_EQ
(
6
,
count1
);
ASSERT_EQ
(
11
,
count1
);
}
}
}
// namespace analysis
}
// namespace analysis
...
...
paddle/fluid/operators/tensorrt_engine_op.h
浏览文件 @
03ff4f68
...
@@ -160,11 +160,21 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
...
@@ -160,11 +160,21 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
fluid_t
->
mutable_data
<
float
>
(
platform
::
CUDAPlace
(
fluid_t
->
mutable_data
<
float
>
(
platform
::
CUDAPlace
(
boost
::
get
<
platform
::
CUDAPlace
>
(
context
.
GetPlace
()).
device
)),
boost
::
get
<
platform
::
CUDAPlace
>
(
context
.
GetPlace
()).
device
)),
size
*
sizeof
(
float
));
size
*
sizeof
(
float
));
//} else {
// engine->GetOutputInGPU(
// TODO(zhaolong) : delete it sometimes
// y, fluid_t->mutable_data<float>(platform::CUDAPlace()),
/* THIS CODE JUST FOR TEST
// size * sizeof(float));
std::cout << output_maps[output_index] << std::endl;
//}
platform::CPUPlace cpu_place;
framework::LoDTensor temp_tensor;
temp_tensor.Resize(framework::make_ddim(ddim));
auto* temp_data = temp_tensor.mutable_data<float>(cpu_place);
TensorCopySync(*fluid_t, cpu_place ,&temp_tensor);
for(int i = 0; i < size; i++) {
std::cout << temp_data[i] << " " ;
}
std::cout << std::endl;
*/
output_index
+=
1
;
output_index
+=
1
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录