Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b1401fb7
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b1401fb7
编写于
1月 07, 2020
作者:
Y
Yiqun Liu
提交者:
石晓伟
1月 07, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove subgraph_detector from inference/analysis to the common framework/ir directory. (#22094)
test=develop
上级
50bee83f
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
660 addition
and
673 deletion
+660
-673
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+2
-1
paddle/fluid/framework/ir/ngraph_subgraph_pass.cc
paddle/fluid/framework/ir/ngraph_subgraph_pass.cc
+7
-11
paddle/fluid/framework/ir/subgraph_detector.cc
paddle/fluid/framework/ir/subgraph_detector.cc
+472
-474
paddle/fluid/framework/ir/subgraph_detector.h
paddle/fluid/framework/ir/subgraph_detector.h
+154
-160
paddle/fluid/inference/analysis/ir_pass_manager.cc
paddle/fluid/inference/analysis/ir_pass_manager.cc
+0
-1
paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt
paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt
+5
-8
paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc
...luid/inference/analysis/ir_passes/anakin_subgraph_pass.cc
+9
-8
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
...id/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+11
-10
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
b1401fb7
...
...
@@ -39,6 +39,7 @@ cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
cc_library
(
pass SRCS pass.cc DEPS graph node graph_helper
)
cc_library
(
graph_traits SRCS graph_traits.cc DEPS graph
)
cc_library
(
graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits
)
cc_library
(
subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor
)
cc_library
(
fuse_pass_base SRCS fuse_pass_base.cc DEPS pass
)
cc_library
(
placement_pass_base SRCS placement_pass_base.cc DEPS pass
)
...
...
@@ -99,7 +100,7 @@ endif()
if
(
WITH_NGRAPH
)
cc_library
(
ngraph_subgraph_pass SRCS ngraph_subgraph_pass.cc DEPS ngraph_bridge
analysis_helper subgraph_detector graph_pattern_detector pass
fuse_pass_base
${
op_library_DEPS
}
)
subgraph_detector
fuse_pass_base
${
op_library_DEPS
}
)
set
(
pass_file
${
PADDLE_BINARY_DIR
}
/paddle/fluid/inference/api/paddle_inference_pass.h
)
file
(
APPEND
${
pass_file
}
"USE_PASS(ngraph_subgraph_pass);
\n
"
)
set
(
INFER_IR_PASSES
${
INFER_IR_PASSES
}
ngraph_subgraph_pass CACHE INTERNAL
""
)
...
...
paddle/fluid/framework/ir/ngraph_subgraph_pass.cc
浏览文件 @
b1401fb7
...
...
@@ -20,8 +20,7 @@
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/ngraph_subgraph_pass.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
#include "paddle/fluid/operators/ngraph/ngraph_bridge.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
...
...
@@ -30,8 +29,6 @@ namespace paddle {
namespace
framework
{
namespace
ir
{
namespace
ANAT
=
paddle
::
inference
::
analysis
;
std
::
string
GenerateEngineKey
(
const
std
::
set
<
std
::
string
>
&
engine_inputs
,
const
std
::
set
<
std
::
string
>
&
engine_outputs
,
const
std
::
string
&
size
)
{
...
...
@@ -59,19 +56,18 @@ void NgraphSubgraphPass::ApplyImpl(Graph *graph) const {
return
!
paddle
::
operators
::
NgraphBridge
::
isRegister
(
op_type
);
};
ANAT
::
SubGraphFuser
fuser
(
graph
,
teller
,
0
,
"ngraph_engine"
);
SubGraphFuser
fuser
(
graph
,
teller
,
0
,
"ngraph_engine"
);
fuser
();
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
!
A
NAT
::
A
gent
(
node
).
subgraph
()
->
empty
())
{
if
(
node
->
IsOp
()
&&
!
Agent
(
node
).
subgraph
()
->
empty
())
{
OpDesc
*
op_desc
=
node
->
Op
();
op_desc
->
SetType
(
"ngraph_engine"
);
CreateNgraphEngineOp
(
node
,
graph
);
std
::
unordered_set
<
const
Node
*>
nodes2remove
(
ANAT
::
Agent
(
node
).
subgraph
()
->
begin
(),
ANAT
::
Agent
(
node
).
subgraph
()
->
end
());
Agent
(
node
).
subgraph
()
->
begin
(),
Agent
(
node
).
subgraph
()
->
end
());
GraphSafeRemoveNodes
(
graph
,
nodes2remove
);
}
...
...
@@ -79,7 +75,7 @@ void NgraphSubgraphPass::ApplyImpl(Graph *graph) const {
std
::
unordered_set
<
const
Node
*>
nodes2remove
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
A
NAT
::
A
gent
(
node
).
deleted
())
{
if
(
node
->
IsOp
()
&&
Agent
(
node
).
deleted
())
{
nodes2remove
.
insert
(
node
);
}
}
...
...
@@ -116,7 +112,7 @@ void UpdateNgraphIO(Node *node, Graph *graph,
return
;
}
auto
&
subgraph
=
*
A
NAT
::
A
gent
(
node
).
subgraph
();
auto
&
subgraph
=
*
Agent
(
node
).
subgraph
();
std
::
unordered_set
<
std
::
string
>
inputs
;
std
::
unordered_set
<
std
::
string
>
outputs
;
for
(
auto
*
node
:
subgraph
)
{
...
...
@@ -138,7 +134,7 @@ void UpdateNgraphIO(Node *node, Graph *graph,
}
void
NgraphSubgraphPass
::
CreateNgraphEngineOp
(
Node
*
node
,
Graph
*
graph
)
const
{
auto
&
subgraph
=
*
A
NAT
::
A
gent
(
node
).
subgraph
();
auto
&
subgraph
=
*
Agent
(
node
).
subgraph
();
PADDLE_ENFORCE_NE
(
subgraph
.
empty
(),
true
,
"subgraph cannot be empty"
);
framework
::
proto
::
BlockDesc
block_proto
;
...
...
paddle/fluid/
inference/analysis/ir_passes
/subgraph_detector.cc
→
paddle/fluid/
framework/ir
/subgraph_detector.cc
浏览文件 @
b1401fb7
...
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/
inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/
framework/ir/subgraph_detector.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
...
...
@@ -24,10 +24,8 @@ limitations under the License. */
DECLARE_bool
(
use_ngraph
);
namespace
paddle
{
namespace
inference
{
namespace
analysis
{
using
framework
::
ir
::
Node
;
namespace
framework
{
namespace
ir
{
std
::
pair
<
std
::
vector
<
Node
*>
,
std
::
vector
<
Node
*>>
ExtractInputAndOutputOfSubGraph
(
std
::
vector
<
Node
*>
&
graph
)
{
// NOLINT
...
...
@@ -469,6 +467,6 @@ inline bool CheckNodeIndegreeEquals(const Node &node, size_t n) {
return
node
.
inputs
.
size
()
==
n
;
}
}
// namespace
analysis
}
// namespace
inference
}
// namespace
ir
}
// namespace
framework
}
// namespace paddle
paddle/fluid/
inference/analysis/ir_passes
/subgraph_detector.h
→
paddle/fluid/
framework/ir
/subgraph_detector.h
浏览文件 @
b1401fb7
...
...
@@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file defines the the class to partition a graph.
*/
#pragma once
#include <string>
...
...
@@ -23,15 +19,10 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/helper.h"
namespace
paddle
{
namespace
inference
{
namespace
analysis
{
using
framework
::
ir
::
Graph
;
using
framework
::
ir
::
NodesTSIterator
;
namespace
framework
{
namespace
ir
{
const
char
kIsFunctionNode
[]
=
"__is_function_node__"
;
const
char
kFunctionNodeSubGraph
[]
=
"__function_node_sub_graph__"
;
...
...
@@ -45,13 +36,12 @@ const char kSubgraphSplitterMarkerAttrName[] =
class
SubgraphDetector
{
public:
// Tell whether a node is inside a sub-graph.
using
NodeInsideSubgraphTeller
=
std
::
function
<
bool
(
const
framework
::
ir
::
Node
*
)
>
;
using
NodeInsideSubgraphTeller
=
std
::
function
<
bool
(
const
Node
*
)
>
;
SubgraphDetector
(
Graph
*
graph
,
const
NodeInsideSubgraphTeller
&
teller
)
:
graph_
(
graph
),
node_inside_subgraph_teller_
(
teller
)
{}
std
::
vector
<
std
::
vector
<
framework
::
ir
::
Node
*>>
operator
()();
std
::
vector
<
std
::
vector
<
Node
*>>
operator
()();
protected:
// Mark the nodes inside the accepted sub-graph using
...
...
@@ -59,7 +49,7 @@ class SubgraphDetector {
void
MarkNodesInsideSubGraph
();
// Merge the marked nodes into sub-graphs and return the sub-graphs.
std
::
vector
<
std
::
vector
<
framework
::
ir
::
Node
*>>
ExtractSubGraphs
();
std
::
vector
<
std
::
vector
<
Node
*>>
ExtractSubGraphs
();
private:
Graph
*
graph_
;
...
...
@@ -99,14 +89,14 @@ struct NodeWrapper {
bool
deleted
{
false
};
bool
marked
{
false
};
int
union_find_parent
{
-
1
};
std
::
vector
<
framework
::
ir
::
Node
*>
subgraph
;
std
::
vector
<
Node
*>
subgraph
;
};
/*
* ir::Node agent for subgraph detector.
*/
struct
Agent
{
explicit
Agent
(
framework
::
ir
::
Node
*
x
)
:
x_
(
x
)
{}
explicit
Agent
(
Node
*
x
)
:
x_
(
x
)
{}
NodeWrapper
&
wrapper
()
{
if
(
!
x_
->
IsWrappedBy
<
NodeWrapper
>
())
{
...
...
@@ -128,17 +118,17 @@ struct Agent {
int
union_find_parent
()
{
return
wrapper
().
union_find_parent
;
}
void
set_union_find_parent
(
int
v
)
{
wrapper
().
union_find_parent
=
v
;
}
std
::
vector
<
framework
::
ir
::
Node
*>
*
subgraph
()
{
return
&
wrapper
().
subgraph
;
}
std
::
vector
<
framework
::
ir
::
Node
*>
&
inputs
()
{
return
x_
->
inputs
;
}
std
::
vector
<
framework
::
ir
::
Node
*>
&
outputs
()
{
return
x_
->
outputs
;
}
std
::
vector
<
Node
*>
*
subgraph
()
{
return
&
wrapper
().
subgraph
;
}
std
::
vector
<
Node
*>
&
inputs
()
{
return
x_
->
inputs
;
}
std
::
vector
<
Node
*>
&
outputs
()
{
return
x_
->
outputs
;
}
private:
framework
::
ir
::
Node
*
x_
;
Node
*
x_
;
};
// The nodes those have no input will be treated as start points.
static
std
::
vector
<
framework
::
ir
::
Node
*>
ExtractStartPoints
(
const
Graph
&
g
)
{
std
::
vector
<
framework
::
ir
::
Node
*>
result
;
static
std
::
vector
<
Node
*>
ExtractStartPoints
(
const
Graph
&
g
)
{
std
::
vector
<
Node
*>
result
;
for
(
auto
*
node
:
g
.
Nodes
())
{
if
(
node
->
inputs
.
empty
())
{
result
.
push_back
(
node
);
...
...
@@ -149,12 +139,16 @@ static std::vector<framework::ir::Node *> ExtractStartPoints(const Graph &g) {
static
iterator_range
<
NodesTSIterator
>
TopologicalSort
(
const
Graph
&
g
)
{
auto
start_points
=
ExtractStartPoints
(
g
);
PADDLE_ENFORCE
(
!
start_points
.
empty
());
PADDLE_ENFORCE_GT
(
start_points
.
size
(),
0U
,
platform
::
errors
::
InvalidArgument
(
"Expected the number of graph's start points >= 1. Expected %d."
,
start_points
.
size
()));
NodesTSIterator
x
(
start_points
);
return
iterator_range
<
NodesTSIterator
>
(
NodesTSIterator
(
start_points
),
NodesTSIterator
());
}
}
// namespace
analysis
}
// namespace
inference
}
// namespace
ir
}
// namespace
framework
}
// namespace paddle
paddle/fluid/inference/analysis/ir_pass_manager.cc
浏览文件 @
b1401fb7
...
...
@@ -24,7 +24,6 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/string/pretty_log.h"
namespace
paddle
{
...
...
paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt
浏览文件 @
b1401fb7
cc_library
(
subgraph_detector SRCS subgraph_detector.cc subgraph_util.cc DEPS proto_desc
)
if
(
WITH_TESTING
)
add_dependencies
(
subgraph_detector gtest
)
endif
()
cc_library
(
subgraph_util SRCS subgraph_util.cc DEPS subgraph_detector
)
if
(
WITH_GPU AND TENSORRT_FOUND
)
cc_library
(
tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass.cc DEPS subgraph_
detector
tensorrt_op_teller
)
cc_library
(
tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass.cc DEPS subgraph_
util
tensorrt_op_teller
)
set
(
analysis_deps
${
analysis_deps
}
subgraph_
detector
tensorrt_subgraph_pass
subgraph_
util
tensorrt_subgraph_pass
CACHE INTERNAL
""
)
set
(
pass_file
${
PADDLE_BINARY_DIR
}
/paddle/fluid/inference/api/paddle_inference_pass.h
)
...
...
@@ -16,10 +13,10 @@ if (WITH_GPU AND TENSORRT_FOUND)
endif
()
if
(
ANAKIN_SUBGRAPH
)
cc_library
(
anakin_subgraph_pass SRCS anakin_subgraph_pass.cc DEPS subgraph_
detector
anakin_op_teller
)
cc_library
(
anakin_subgraph_pass SRCS anakin_subgraph_pass.cc DEPS subgraph_
util
anakin_op_teller
)
set
(
analysis_deps
${
analysis_deps
}
subgraph_
detector
anakin_subgraph_pass
subgraph_
util
anakin_subgraph_pass
CACHE INTERNAL
""
)
set
(
pass_file
${
PADDLE_BINARY_DIR
}
/paddle/fluid/inference/api/paddle_inference_pass.h
)
...
...
paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc
浏览文件 @
b1401fb7
...
...
@@ -22,11 +22,11 @@
#include <vector>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
#include "paddle/fluid/inference/anakin/op_teller.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/string/pretty_log.h"
namespace
paddle
{
...
...
@@ -50,7 +50,7 @@ void analysis::AnakinSubgraphPass::ApplyImpl(
return
anakin
::
OpTeller
::
Global
().
Tell
(
node
->
Op
()
->
Type
(),
*
node
->
Op
());
};
SubGraphFuser
fuser
(
graph
,
teller
,
6
/* min_subgraph_size */
);
framework
::
ir
::
SubGraphFuser
fuser
(
graph
,
teller
,
6
/* min_subgraph_size */
);
fuser
();
std
::
vector
<
std
::
string
>
graph_param_names
=
...
...
@@ -61,17 +61,18 @@ void analysis::AnakinSubgraphPass::ApplyImpl(
std
::
vector
<
std
::
string
>
repetitive_params
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
!
Agent
(
node
).
subgraph
()
->
empty
())
{
if
(
node
->
IsOp
()
&&
!
framework
::
ir
::
Agent
(
node
).
subgraph
()
->
empty
())
{
CreateAnakinOp
(
node
,
graph
,
graph_param_names
,
&
repetitive_params
);
std
::
unordered_set
<
const
Node
*>
nodes2remove
(
Agent
(
node
).
subgraph
()
->
begin
(),
Agent
(
node
).
subgraph
()
->
end
());
framework
::
ir
::
Agent
(
node
).
subgraph
()
->
begin
(),
framework
::
ir
::
Agent
(
node
).
subgraph
()
->
end
());
framework
::
ir
::
GraphSafeRemoveNodes
(
graph
,
nodes2remove
);
}
}
std
::
unordered_set
<
const
Node
*>
nodes2remove
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
Agent
(
node
).
deleted
())
{
if
(
node
->
IsOp
()
&&
framework
::
ir
::
Agent
(
node
).
deleted
())
{
nodes2remove
.
insert
(
node
);
}
}
...
...
@@ -96,11 +97,11 @@ std::string GenerateAnakinEngineKey(const std::set<std::string> &engine_inputs,
}
void
AnakinSubgraphPass
::
CreateAnakinOp
(
framework
::
ir
::
Node
*
node
,
Graph
*
graph
,
framework
::
ir
::
Node
*
node
,
framework
::
ir
::
Graph
*
graph
,
const
std
::
vector
<
std
::
string
>
&
graph_params
,
std
::
vector
<
std
::
string
>
*
repetitive_params
)
const
{
auto
*
op_desc
=
node
->
Op
();
auto
&
subgraph
=
*
Agent
(
node
).
subgraph
();
auto
&
subgraph
=
*
framework
::
ir
::
Agent
(
node
).
subgraph
();
PADDLE_ENFORCE
(
!
subgraph
.
empty
());
framework
::
ProgramDesc
*
program_desc
=
...
...
@@ -164,7 +165,7 @@ void AnakinSubgraphPass::CreateAnakinOp(
graph_var_map
[
node
->
Name
()]
=
node
;
}
}
auto
&
subgraph_nodes
=
*
Agent
(
node
).
subgraph
();
auto
&
subgraph_nodes
=
*
framework
::
ir
::
Agent
(
node
).
subgraph
();
// The following procedure is used to rename all the intermediate
// variables and the output variables of the subgraph.
...
...
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
浏览文件 @
b1401fb7
...
...
@@ -17,8 +17,8 @@
#include <set>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
...
...
@@ -40,8 +40,8 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
return
tensorrt
::
OpTeller
::
Global
().
Tell
(
node
->
Op
()
->
Type
(),
*
node
->
Op
());
};
SubGraphFuser
fuser
(
graph
,
teller
,
Get
<
int
>
(
"min_subgraph_size"
)
/*min subgraph size*/
,
framework
::
ir
::
SubGraphFuser
fuser
(
graph
,
teller
,
Get
<
int
>
(
"min_subgraph_size"
)
/*min subgraph size*/
,
"tensorrt_engine"
);
fuser
();
...
...
@@ -52,18 +52,19 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
std
::
vector
<
std
::
string
>
repetitive_params
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
!
Agent
(
node
).
subgraph
()
->
empty
())
{
if
(
node
->
IsOp
()
&&
!
framework
::
ir
::
Agent
(
node
).
subgraph
()
->
empty
())
{
CreateTensorRTOp
(
node
,
graph
,
graph_param_names
,
&
repetitive_params
);
std
::
unordered_set
<
const
Node
*>
nodes2remove
(
Agent
(
node
).
subgraph
()
->
begin
(),
Agent
(
node
).
subgraph
()
->
end
());
framework
::
ir
::
Agent
(
node
).
subgraph
()
->
begin
(),
framework
::
ir
::
Agent
(
node
).
subgraph
()
->
end
());
framework
::
ir
::
GraphSafeRemoveNodes
(
graph
,
nodes2remove
);
}
}
std
::
unordered_set
<
const
Node
*>
nodes2remove
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
Agent
(
node
).
deleted
())
{
if
(
node
->
IsOp
()
&&
framework
::
ir
::
Agent
(
node
).
deleted
())
{
nodes2remove
.
insert
(
node
);
}
}
...
...
@@ -88,11 +89,11 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
}
void
TensorRtSubgraphPass
::
CreateTensorRTOp
(
framework
::
ir
::
Node
*
node
,
Graph
*
graph
,
framework
::
ir
::
Node
*
node
,
framework
::
ir
::
Graph
*
graph
,
const
std
::
vector
<
std
::
string
>
&
graph_params
,
std
::
vector
<
std
::
string
>
*
repetitive_params
)
const
{
auto
*
op_desc
=
node
->
Op
();
auto
&
subgraph
=
*
Agent
(
node
).
subgraph
();
auto
&
subgraph
=
*
framework
::
ir
::
Agent
(
node
).
subgraph
();
PADDLE_ENFORCE
(
!
subgraph
.
empty
());
framework
::
ProgramDesc
*
program_desc
=
...
...
@@ -161,7 +162,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
if
(
precision_mode
==
AnalysisConfig
::
Precision
::
kHalf
)
enable_fp16
=
true
;
auto
enable_int8
=
Get
<
bool
>
(
"enable_int8"
);
auto
use_calib_mode
=
Get
<
bool
>
(
"use_calib_mode"
);
auto
&
subgraph_nodes
=
*
Agent
(
node
).
subgraph
();
auto
&
subgraph_nodes
=
*
framework
::
ir
::
Agent
(
node
).
subgraph
();
// The following procedure is used to rename all the intermediate
// variables and the output variables of the subgraph.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录