Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
68a07328
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
68a07328
编写于
1月 07, 2019
作者:
M
minqiyang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into add_pyramid_dnn_support
test=develop
上级
00e4de04
317840d3
变更
56
隐藏空白更改
内联
并排
Showing
56 changed file
with
1533 addition
and
978 deletion
+1533
-978
cmake/configure.cmake
cmake/configure.cmake
+1
-0
cmake/external/ngraph.cmake
cmake/external/ngraph.cmake
+1
-1
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+39
-15
paddle/fluid/framework/details/build_strategy.h
paddle/fluid/framework/details/build_strategy.h
+6
-2
paddle/fluid/framework/details/multi_devices_graph_check_pass.cc
...fluid/framework/details/multi_devices_graph_check_pass.cc
+57
-47
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+475
-389
paddle/fluid/framework/details/multi_devices_graph_pass.h
paddle/fluid/framework/details/multi_devices_graph_pass.h
+106
-38
paddle/fluid/framework/naive_executor.cc
paddle/fluid/framework/naive_executor.cc
+8
-8
paddle/fluid/framework/ngraph_operator.cc
paddle/fluid/framework/ngraph_operator.cc
+1
-1
paddle/fluid/inference/analysis/argument.h
paddle/fluid/inference/analysis/argument.h
+0
-2
paddle/fluid/inference/analysis/ir_pass_manager.cc
paddle/fluid/inference/analysis/ir_pass_manager.cc
+0
-10
paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt
paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt
+11
-7
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
...id/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+5
-3
paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc
...uid/inference/analysis/passes/ir_analysis_compose_pass.cc
+0
-23
paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.h
...luid/inference/analysis/passes/ir_analysis_compose_pass.h
+0
-2
paddle/fluid/inference/api/analysis_config.cc
paddle/fluid/inference/api/analysis_config.cc
+154
-66
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+46
-37
paddle/fluid/inference/api/analysis_predictor_tester.cc
paddle/fluid/inference/api/analysis_predictor_tester.cc
+15
-15
paddle/fluid/inference/api/api_anakin_engine.h
paddle/fluid/inference/api/api_anakin_engine.h
+0
-2
paddle/fluid/inference/api/api_impl.cc
paddle/fluid/inference/api/api_impl.cc
+1
-1
paddle/fluid/inference/api/api_impl_tester.cc
paddle/fluid/inference/api/api_impl_tester.cc
+2
-1
paddle/fluid/inference/api/demo_ci/trt_mobilenet_demo.cc
paddle/fluid/inference/api/demo_ci/trt_mobilenet_demo.cc
+4
-5
paddle/fluid/inference/api/demo_ci/vis_demo.cc
paddle/fluid/inference/api/demo_ci/vis_demo.cc
+6
-7
paddle/fluid/inference/api/paddle_analysis_config.h
paddle/fluid/inference/api/paddle_analysis_config.h
+88
-21
paddle/fluid/inference/api/paddle_inference_api.h
paddle/fluid/inference/api/paddle_inference_api.h
+2
-3
paddle/fluid/inference/api/paddle_pass_builder.h
paddle/fluid/inference/api/paddle_pass_builder.h
+11
-1
paddle/fluid/inference/tensorrt/CMakeLists.txt
paddle/fluid/inference/tensorrt/CMakeLists.txt
+1
-0
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+49
-0
paddle/fluid/inference/tensorrt/op_teller.h
paddle/fluid/inference/tensorrt/op_teller.h
+68
-0
paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
+3
-6
paddle/fluid/inference/tests/api/analyzer_lac_tester.cc
paddle/fluid/inference/tests/api/analyzer_lac_tester.cc
+4
-5
paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc
paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc
+4
-5
paddle/fluid/inference/tests/api/analyzer_ner_tester.cc
paddle/fluid/inference/tests/api/analyzer_ner_tester.cc
+5
-6
paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc
paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc
+4
-6
paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc
paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc
+14
-14
paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc
paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc
+4
-6
paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc
...le/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc
+4
-5
paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc
...le/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc
+3
-6
paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc
...nference/tests/api/analyzer_text_classification_tester.cc
+4
-5
paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
+5
-6
paddle/fluid/inference/tests/api/config_printer.h
paddle/fluid/inference/tests/api/config_printer.h
+10
-6
paddle/fluid/inference/tests/api/tester_helper.h
paddle/fluid/inference/tests/api/tester_helper.h
+4
-1
paddle/fluid/inference/tests/api/trt_models_tester.cc
paddle/fluid/inference/tests/api/trt_models_tester.cc
+12
-12
paddle/fluid/operators/conv_mkldnn_op.cc
paddle/fluid/operators/conv_mkldnn_op.cc
+46
-23
paddle/fluid/operators/linear_chain_crf_op.cc
paddle/fluid/operators/linear_chain_crf_op.cc
+2
-0
paddle/fluid/operators/math/blas_impl.cu.h
paddle/fluid/operators/math/blas_impl.cu.h
+64
-70
paddle/fluid/operators/optimizers/adam_op.h
paddle/fluid/operators/optimizers/adam_op.h
+10
-3
paddle/fluid/platform/cuda_helper.h
paddle/fluid/platform/cuda_helper.h
+58
-0
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+13
-5
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+24
-52
paddle/fluid/platform/device_context_test.cu
paddle/fluid/platform/device_context_test.cu
+0
-3
paddle/fluid/platform/mkldnn_reuse.h
paddle/fluid/platform/mkldnn_reuse.h
+5
-3
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+4
-7
python/paddle/fluid/parallel_executor.py
python/paddle/fluid/parallel_executor.py
+14
-0
python/paddle/fluid/tests/unittests/test_conv2d_int8_mkldnn_op.py
...addle/fluid/tests/unittests/test_conv2d_int8_mkldnn_op.py
+56
-14
python/paddle/fluid/tests/unittests/test_reader_reset.py
python/paddle/fluid/tests/unittests/test_reader_reset.py
+0
-2
未找到文件。
cmake/configure.cmake
浏览文件 @
68a07328
...
@@ -134,6 +134,7 @@ if(WITH_GPU)
...
@@ -134,6 +134,7 @@ if(WITH_GPU)
message
(
WARNING
"Anakin needs CUDNN >= 7.0 to compile. Force WITH_ANAKIN=OFF"
)
message
(
WARNING
"Anakin needs CUDNN >= 7.0 to compile. Force WITH_ANAKIN=OFF"
)
set
(
WITH_ANAKIN OFF CACHE STRING
"Anakin is valid only when CUDNN >= 7.0."
FORCE
)
set
(
WITH_ANAKIN OFF CACHE STRING
"Anakin is valid only when CUDNN >= 7.0."
FORCE
)
endif
()
endif
()
add_definitions
(
-DWITH_ANAKIN
)
endif
()
endif
()
if
(
WITH_ANAKIN
)
if
(
WITH_ANAKIN
)
# NOTICE(minqiyang): the end slash is important because $CUDNN_INCLUDE_DIR
# NOTICE(minqiyang): the end slash is important because $CUDNN_INCLUDE_DIR
...
...
cmake/external/ngraph.cmake
浏览文件 @
68a07328
...
@@ -37,7 +37,7 @@ INCLUDE(GNUInstallDirs)
...
@@ -37,7 +37,7 @@ INCLUDE(GNUInstallDirs)
INCLUDE
(
ExternalProject
)
INCLUDE
(
ExternalProject
)
SET
(
NGRAPH_PROJECT
"extern_ngraph"
)
SET
(
NGRAPH_PROJECT
"extern_ngraph"
)
SET
(
NGRAPH_GIT_TAG
"
v0.10.1
"
)
SET
(
NGRAPH_GIT_TAG
"
08851c2c45fcf9fa9c74871dd3dbc3fe38f37cc9
"
)
SET
(
NGRAPH_SOURCES_DIR
${
THIRD_PARTY_PATH
}
/ngraph
)
SET
(
NGRAPH_SOURCES_DIR
${
THIRD_PARTY_PATH
}
/ngraph
)
SET
(
NGRAPH_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/ngraph
)
SET
(
NGRAPH_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/ngraph
)
SET
(
NGRAPH_INC_DIR
${
NGRAPH_INSTALL_DIR
}
/include
)
SET
(
NGRAPH_INC_DIR
${
NGRAPH_INSTALL_DIR
}
/include
)
...
...
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
68a07328
...
@@ -18,7 +18,7 @@ limitations under the License. */
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include <memory>
#include <memory>
#include "paddle/fluid/framework/details/memory_reuse_types.h"
#include "paddle/fluid/framework/details/memory_reuse_types.h"
#include "paddle/fluid/framework/details/multi_devices_graph_
check_
pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
...
@@ -86,10 +86,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -86,10 +86,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
if
(
strategy
.
memory_optimize_
)
{
if
(
strategy
.
memory_optimize_
)
{
auto
analysis_var_pass
=
AppendPass
(
"analysis_var_pass"
);
auto
analysis_var_pass
=
AppendPass
(
"analysis_var_pass"
);
}
}
// Convert graph to run on multi-devices.
auto
multi_devices_pass
=
AppendPass
(
"multi_devices_pass"
);
AppendMultiDevPass
(
strategy
);
multi_devices_pass
->
SetNotOwned
<
const
BuildStrategy
>
(
"strategy"
,
&
strategy_
);
// Add a graph print pass to record a graph with device info.
// Add a graph print pass to record a graph with device info.
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
...
@@ -115,6 +113,25 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -115,6 +113,25 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
}
}
}
}
// Convert graph to run on multi-devices.
void
AppendMultiDevPass
(
const
BuildStrategy
&
strategy
)
{
ir
::
Pass
*
multi_devices_pass
;
if
(
strategy_
.
is_distribution_
)
{
multi_devices_pass
=
AppendPass
(
"dist_multi_devices_pass"
).
get
();
}
else
{
if
(
strategy
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
multi_devices_pass
=
AppendPass
(
"allreduce_mode_multi_devices_pass"
).
get
();
}
else
if
(
strategy
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
multi_devices_pass
=
AppendPass
(
"reduce_mode_multi_devices_pass"
).
get
();
}
else
{
PADDLE_THROW
(
"Unknown reduce strategy."
);
}
}
multi_devices_pass
->
SetNotOwned
<
const
BuildStrategy
>
(
"strategy"
,
&
strategy_
);
}
private:
private:
BuildStrategy
strategy_
;
BuildStrategy
strategy_
;
};
};
...
@@ -131,6 +148,10 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
...
@@ -131,6 +148,10 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
return
pass_builder_
;
return
pass_builder_
;
}
}
bool
BuildStrategy
::
IsMultiDevPass
(
const
std
::
string
&
pass_name
)
const
{
return
framework
::
details
::
MultiDevSSAGraphBuilder
().
count
(
pass_name
)
>
0
;
}
std
::
unique_ptr
<
ir
::
Graph
>
BuildStrategy
::
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
BuildStrategy
::
Apply
(
const
ProgramDesc
&
main_program
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
ProgramDesc
&
main_program
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
string
&
loss_var_name
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
...
@@ -145,22 +166,23 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
...
@@ -145,22 +166,23 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
main_program
));
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
main_program
));
for
(
std
::
shared_ptr
<
ir
::
Pass
>
&
pass
:
pass_builder_
->
AllPasses
())
{
for
(
std
::
shared_ptr
<
ir
::
Pass
>
&
pass
:
pass_builder_
->
AllPasses
())
{
if
(
pass
->
Type
()
==
"multi_devices_pass"
)
{
if
(
IsMultiDevPass
(
pass
->
Type
())
)
{
pass
->
Erase
(
"places"
);
pass
->
Erase
(
kPlaces
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
"places"
,
&
places
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
,
&
places
);
pass
->
Erase
(
"loss_var_name"
);
pass
->
Erase
(
kLossVarName
);
pass
->
SetNotOwned
<
const
std
::
string
>
(
"loss_var_name"
,
&
loss_var_name
);
pass
->
SetNotOwned
<
const
std
::
string
>
(
kLossVarName
,
&
loss_var_name
);
pass
->
Erase
(
"local_scopes"
);
pass
->
Erase
(
kLocalScopes
);
pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
"local_scopes"
,
pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
kLocalScopes
,
&
local_scopes
);
&
local_scopes
);
pass
->
Erase
(
"nranks"
);
pass
->
Erase
(
kNRanks
);
pass
->
Set
<
size_t
>
(
"nranks"
,
new
size_t
(
nranks
));
pass
->
Set
<
size_t
>
(
kNRanks
,
new
size_t
(
nranks
));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform
::
NCCLContextMap
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
platform
::
NCCLContextMap
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
pass
->
Erase
(
"nccl_ctxs"
);
pass
->
Erase
(
"nccl_ctxs"
);
pass
->
SetNotOwned
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
,
nctx
);
pass
->
SetNotOwned
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
,
nctx
);
#endif
#endif
}
else
if
(
pass
->
Type
()
==
"analysis_var_pass"
)
{
}
else
if
(
pass
->
Type
()
==
"analysis_var_pass"
)
{
const
std
::
vector
<
OpDesc
*>
*
all_op_descs
=
const
std
::
vector
<
OpDesc
*>
*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
main_program
.
Block
(
0
).
AllOps
());
new
std
::
vector
<
OpDesc
*>
(
main_program
.
Block
(
0
).
AllOps
());
...
@@ -201,7 +223,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
...
@@ -201,7 +223,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
USE_PASS
(
fuse_elewise_add_act_pass
);
USE_PASS
(
fuse_elewise_add_act_pass
);
USE_PASS
(
graph_viz_pass
);
USE_PASS
(
graph_viz_pass
);
USE_PASS
(
multi_batch_merge_pass
);
USE_PASS
(
multi_batch_merge_pass
);
USE_PASS
(
multi_devices_pass
);
USE_PASS
(
reduce_mode_multi_devices_pass
);
USE_PASS
(
allreduce_mode_multi_devices_pass
);
USE_PASS
(
dist_multi_devices_pass
);
USE_PASS
(
multi_devices_check_pass
);
USE_PASS
(
multi_devices_check_pass
);
USE_PASS
(
multi_devices_print_pass
);
USE_PASS
(
multi_devices_print_pass
);
USE_PASS
(
analysis_var_pass
);
USE_PASS
(
analysis_var_pass
);
...
...
paddle/fluid/framework/details/build_strategy.h
浏览文件 @
68a07328
...
@@ -74,8 +74,6 @@ struct BuildStrategy {
...
@@ -74,8 +74,6 @@ struct BuildStrategy {
bool
fuse_elewise_add_act_ops_
{
false
};
bool
fuse_elewise_add_act_ops_
{
false
};
bool
enable_data_balance_
{
false
};
bool
memory_optimize_
{
false
};
bool
memory_optimize_
{
false
};
bool
memory_early_delete_
{
false
};
bool
memory_early_delete_
{
false
};
...
@@ -84,6 +82,10 @@ struct BuildStrategy {
...
@@ -84,6 +82,10 @@ struct BuildStrategy {
bool
fuse_broadcast_op_
{
false
};
bool
fuse_broadcast_op_
{
false
};
// FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
// num_trainers is 1, so the current fields of build_strategy doesn't tell if
// it's distributed model.
bool
is_distribution_
{
false
};
int
num_trainers_
{
1
};
int
num_trainers_
{
1
};
int
trainer_id_
{
0
};
int
trainer_id_
{
0
};
std
::
vector
<
std
::
string
>
trainers_endpoints_
;
std
::
vector
<
std
::
string
>
trainers_endpoints_
;
...
@@ -104,6 +106,8 @@ struct BuildStrategy {
...
@@ -104,6 +106,8 @@ struct BuildStrategy {
bool
IsFinalized
()
const
{
return
is_finalized_
;
}
bool
IsFinalized
()
const
{
return
is_finalized_
;
}
bool
IsMultiDevPass
(
const
std
::
string
&
pass_name
)
const
;
// Apply the passes built by the pass_builder_. The passes will be
// Apply the passes built by the pass_builder_. The passes will be
// applied to the Program and output an ir::Graph.
// applied to the Program and output an ir::Graph.
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
const
ProgramDesc
&
main_program
,
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
const
ProgramDesc
&
main_program
,
...
...
paddle/fluid/framework/details/multi_devices_graph_check_pass.cc
浏览文件 @
68a07328
...
@@ -12,8 +12,8 @@
...
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include <string>
#include <string>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
...
@@ -21,68 +21,78 @@ namespace paddle {
...
@@ -21,68 +21,78 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
bool
SSAGraghBuilderWithChecker
::
IsValidGraph
(
const
ir
::
Graph
*
graph
)
const
{
class
SSAGraghBuilderWithChecker
:
public
ir
::
Pass
{
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
pending_ops
;
protected:
std
::
unordered_set
<
VarHandleBase
*>
pending_vars
;
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unordered_set
<
VarHandleBase
*>
ready_vars
;
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
std
::
unordered_set
<
OpHandleBase
*>
ready_ops
;
PADDLE_ENFORCE
(
IsValidGraph
(
graph
.
get
()));
return
graph
;
}
auto
insert_pending_var
=
[
&
](
VarHandleBase
*
var
)
{
bool
IsValidGraph
(
const
ir
::
Graph
*
graph
)
const
{
pending_vars
.
insert
(
var
);
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
pending_ops
;
if
(
var
->
GeneratedOp
()
==
nullptr
)
{
std
::
unordered_set
<
VarHandleBase
*>
pending_vars
;
ready_vars
.
emplace
(
var
);
std
::
unordered_set
<
VarHandleBase
*>
ready_vars
;
}
std
::
unordered_set
<
OpHandleBase
*>
ready_ops
;
};
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
)
)
{
auto
insert_pending_var
=
[
&
](
VarHandleBase
*
var
)
{
for
(
auto
&
name_pair
:
var_map
)
{
pending_vars
.
insert
(
var
);
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
if
(
var
->
GeneratedOp
()
==
nullptr
)
{
insert_pending_var
(
version_pai
r
);
ready_vars
.
emplace
(
va
r
);
}
}
}
};
}
for
(
auto
&
var
:
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
))
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
insert_pending_var
(
var
);
for
(
auto
&
name_pair
:
var_map
)
{
}
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
insert_pending_var
(
version_pair
);
}
}
}
for
(
OpHandleBase
*
op
:
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
*
graph
))
{
for
(
auto
&
var
:
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
))
{
if
(
op
->
Inputs
().
empty
())
{
insert_pending_var
(
var
);
ready_ops
.
insert
(
op
);
}
else
{
pending_ops
.
insert
({
op
,
op
->
NoDupInputSize
()});
}
}
}
auto
run_all_ops
=
[
&
](
std
::
unordered_set
<
OpHandleBase
*>
&
set
)
{
for
(
OpHandleBase
*
op
:
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
*
graph
))
{
for
(
auto
*
op
:
set
)
{
if
(
op
->
Inputs
().
empty
())
{
for
(
auto
out
:
op
->
Outputs
())
{
ready_ops
.
insert
(
op
);
ready_vars
.
emplace
(
out
);
}
else
{
pending_ops
.
insert
({
op
,
op
->
NoDupInputSize
()});
}
}
}
}
set
.
clear
();
};
while
(
!
pending_vars
.
empty
())
{
auto
run_all_ops
=
[
&
](
std
::
unordered_set
<
OpHandleBase
*>
&
set
)
{
run_all_ops
(
ready_ops
);
for
(
auto
*
op
:
set
)
{
for
(
auto
out
:
op
->
Outputs
())
{
ready_vars
.
emplace
(
out
);
}
}
set
.
clear
();
};
if
(
ready_vars
.
empty
())
{
while
(
!
pending_vars
.
empty
())
{
return
false
;
run_all_ops
(
ready_ops
);
}
for
(
auto
ready_var
:
ready_vars
)
{
if
(
ready_vars
.
empty
())
{
pending_vars
.
erase
(
ready_var
);
return
false
;
for
(
auto
*
op
:
ready_var
->
PendingOps
())
{
}
auto
&
deps
=
--
pending_ops
[
op
];
if
(
deps
==
0
)
{
for
(
auto
ready_var
:
ready_vars
)
{
ready_ops
.
insert
(
op
);
pending_vars
.
erase
(
ready_var
);
for
(
auto
*
op
:
ready_var
->
PendingOps
())
{
auto
&
deps
=
--
pending_ops
[
op
];
if
(
deps
==
0
)
{
ready_ops
.
insert
(
op
);
}
}
}
}
}
ready_vars
.
clear
();
}
}
re
ady_vars
.
clear
()
;
re
turn
true
;
}
}
return
true
;
}
;
}
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/framework/details/multi_devices_graph_pass.cc
浏览文件 @
68a07328
...
@@ -134,15 +134,8 @@ void AddOutputToLeafOps(ir::Graph *graph) {
...
@@ -134,15 +134,8 @@ void AddOutputToLeafOps(ir::Graph *graph) {
}
}
}
// namespace
}
// namespace
static
const
char
kLossVarName
[]
=
"loss_var_name"
;
void
MultiDevSSAGraphBuilderBase
::
Init
()
const
{
static
const
char
kPlaces
[]
=
"places"
;
static
const
char
kLocalScopes
[]
=
"local_scopes"
;
static
const
char
kStrategy
[]
=
"strategy"
;
static
const
char
kNRanks
[]
=
"nranks"
;
void
MultiDevSSAGraphBuilder
::
Init
()
const
{
all_vars_
.
clear
();
all_vars_
.
clear
();
balance_vars_
.
clear
();
loss_var_name_
=
Get
<
const
std
::
string
>
(
kLossVarName
);
loss_var_name_
=
Get
<
const
std
::
string
>
(
kLossVarName
);
places_
=
Get
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
);
places_
=
Get
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
);
...
@@ -151,31 +144,16 @@ void MultiDevSSAGraphBuilder::Init() const {
...
@@ -151,31 +144,16 @@ void MultiDevSSAGraphBuilder::Init() const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
nccl_ctxs_
=
&
Get
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
);
nccl_ctxs_
=
&
Get
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
);
#endif
#endif
balance_vars_
.
resize
(
places_
.
size
(),
0
);
if
(
strategy_
.
enable_data_balance_
&&
places_
.
size
()
==
1
)
{
LOG
(
WARNING
)
<<
"It is no need to enable data balance when there is only "
"one place. enable_data_balance is set to False."
;
strategy_
.
enable_data_balance_
=
false
;
}
}
}
std
::
unique_ptr
<
ir
::
Graph
>
MultiDevSSAGraphBuilder
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
MultiDevSSAGraphBuilder
Base
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
Init
();
Init
();
// Give the topology sort order and rebuild the graph structure.
std
::
vector
<
ir
::
Node
*>
sorted_ops
=
SortOperations
(
*
graph
);
std
::
vector
<
ir
::
Node
*>
sorted_ops
=
ir
::
TopologySortOperations
(
*
graph
);
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
sorted_ops
=
SortForReduceMode
(
sorted_ops
);
}
auto
nodes
=
graph
->
ReleaseNodes
();
auto
nodes
=
graph
->
ReleaseNodes
();
ir
::
Graph
&
result
=
*
graph
;
ir
::
Graph
&
result
=
*
graph
;
size_t
nranks
=
Get
<
size_t
>
(
kNRanks
);
for
(
auto
&
node
:
nodes
)
{
for
(
auto
&
node
:
nodes
)
{
if
(
node
->
IsVar
()
&&
node
->
Var
())
{
if
(
node
->
IsVar
()
&&
node
->
Var
())
{
all_vars_
.
emplace
(
node
->
Name
(),
node
->
Var
());
all_vars_
.
emplace
(
node
->
Name
(),
node
->
Var
());
...
@@ -187,146 +165,61 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -187,146 +165,61 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
result
.
Set
(
kGraphDepVars
,
new
GraphDepVars
);
result
.
Set
(
kGraphDepVars
,
new
GraphDepVars
);
result
.
Set
(
kGraphOps
,
new
GraphOps
);
result
.
Set
(
kGraphOps
,
new
GraphOps
);
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
bcast_var_name_set
;
bcast_var_name_set
.
resize
(
places_
.
size
());
bool
is_forwarding
=
true
;
bool
is_forwarding
=
true
;
bool
is_dist_train
=
false
;
bool
insert_collection_ops
=
NeedCollectiveOps
();
std
::
unordered_map
<
std
::
string
,
int
>
sharded_var_device
;
for
(
ir
::
Node
*
node
:
sorted_ops
)
{
for
(
ir
::
Node
*
node
:
sorted_ops
)
{
if
(
OpHaveRole
(
*
node
,
OpRole
::
kRPC
))
{
if
(
DealWithSpecialOp
(
&
result
,
node
))
{
int
op_dev_id
=
CreateRPCOp
(
&
result
,
node
,
&
sharded_var_device
);
continue
;
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
"Can not schedule the RPC operator to the right place."
);
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
auto
recv_vars_attr
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetNullableAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
PADDLE_ENFORCE
(
recv_vars_attr
.
size
()
==
2UL
);
// [parameter, gradient]
if
(
recv_vars_attr
[
0
].
find
(
".block"
)
==
std
::
string
::
npos
)
{
bcast_var_name_set
[
op_dev_id
].
emplace
(
recv_vars_attr
[
0
]);
}
}
is_dist_train
=
true
;
}
else
if
(
OpHaveRole
(
*
node
,
OpRole
::
kDist
))
{
int
op_dev_id
=
CreateDistTrainOp
(
&
result
,
node
,
&
sharded_var_device
);
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
auto
origin_param_name
=
node
->
Op
()
->
OutputArgumentNames
()[
0
];
bcast_var_name_set
[
op_dev_id
].
emplace
(
origin_param_name
);
}
}
else
if
(
IsScaleLossOp
(
node
))
{
// user can customize loss@grad if not use_default_grad_scale_
if
(
strategy_
.
gradient_scale_
!=
BuildStrategy
::
GradientScaleStrategy
::
kCustomized
)
{
// TODO(paddle-dev): Why is there no input for this op_handle?
auto
loss_grad_name
=
node
->
Op
()
->
OutputArgumentNames
()[
0
];
auto
out_dtype
=
all_vars_
.
at
(
loss_grad_name
)
->
GetDataType
();
CreateScaleLossGradOp
(
&
result
,
loss_grad_name
,
node
->
outputs
[
0
],
out_dtype
);
}
// This assumes the backward generating code will ensure IsScaleLossOp
// is true only for the op that scale the final scalar loss.
// It also assumes backward op will always follow the forward op in
// the block.
is_forwarding
=
false
;
}
else
{
}
else
{
int
op_dev_id
=
GetOpDeviceID
(
node
,
sharded_var_device
);
// This op runs on all devices
if
(
op_dev_id
!=
-
1
)
{
// This op only runs on one specific device.
if
(
IsScaleLossOp
(
node
))
{
CreateComputationalOp
(
&
result
,
node
,
op_dev_id
);
// user can customize loss@grad if not use_default_grad_scale_
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
InsertScaleLossGradOp
(
&
result
,
node
);
sharded_var_device
.
emplace
(
n
->
Name
(),
op_dev_id
);
// This assumes the backward generating code will ensure IsScaleLossOp
}
// is true only for the op that scale the final scalar loss.
// It also assumes backward op will always follow the forward op in
// the block.
is_forwarding
=
false
;
}
else
{
}
else
{
// This op runs on all devices, and its output may have parameter's
CreateComputationalOps
(
&
result
,
node
,
places_
.
size
());
// gradients.
}
// TODO(paddle-dev): Why is so special about "read" op?
if
(
node
->
Op
()
->
Type
()
==
"read"
&&
strategy_
.
enable_data_balance_
)
{
node
->
Op
()
->
SetAttr
(
"throw_eof_exp"
,
false
);
CreateComputationalOps
(
&
result
,
node
,
places_
.
size
());
const
auto
&
data_var_names
=
node
->
Op
()
->
Output
(
"Out"
);
InsertDataBalanceOp
(
&
result
,
data_var_names
);
}
else
{
CreateComputationalOps
(
&
result
,
node
,
places_
.
size
());
}
if
(
!
is_forwarding
&&
nranks
>
1UL
)
{
// Insert collection ops
if
(
!
is_forwarding
&&
insert_collection_ops
)
{
try
{
bool
is_bk_op
=
bool
is_bk_op
=
static_cast
<
bool
>
(
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
static_cast
<
bool
>
(
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
&
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
&
static_cast
<
int
>
(
OpRole
::
kBackward
));
static_cast
<
int
>
(
OpRole
::
kBackward
));
if
(
!
is_bk_op
)
continue
;
if
(
!
is_bk_op
)
continue
;
// Currently, we assume that once gradient is generated, it can be
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
// broadcast, and each gradient is only broadcast once.
try
{
auto
backward_vars
=
auto
backward_vars
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetNullableAttr
(
node
->
Op
()
->
GetNullableAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
PADDLE_ENFORCE_EQ
(
backward_vars
.
size
()
%
2
,
0
);
PADDLE_ENFORCE_EQ
(
backward_vars
.
size
()
%
2
,
0
);
for
(
size_t
i
=
0
;
i
<
backward_vars
.
size
();
i
+=
2
)
{
auto
&
p_name
=
backward_vars
[
i
];
for
(
size_t
i
=
0
;
i
<
backward_vars
.
size
();
i
+=
2
)
{
auto
&
g_name
=
backward_vars
[
i
+
1
];
auto
&
p_name
=
backward_vars
[
i
];
VLOG
(
10
)
<<
"Bcast "
<<
g_name
<<
" for parameter "
<<
p_name
;
auto
&
g_name
=
backward_vars
[
i
+
1
];
VLOG
(
10
)
<<
"Bcast "
<<
g_name
<<
" for parameter "
<<
p_name
;
InsertCollectiveOp
(
&
result
,
p_name
,
g_name
);
size_t
cur_device_id
=
-
1
;
switch
(
strategy_
.
reduce_
)
{
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
sharded_var_device
.
emplace
(
g_name
,
cur_device_id
);
if
(
!
is_dist_train
)
{
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
}
break
;
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
if
(
IsSparseGradient
(
g_name
))
{
CreateReduceOp
(
&
result
,
g_name
,
0
);
CreateBroadcastOp
(
&
result
,
g_name
,
0
);
}
else
{
InsertAllReduceOp
(
&
result
,
g_name
);
}
break
;
default:
LOG
(
FATAL
)
<<
"Unknown reduce strategy "
;
break
;
}
}
}
catch
(
boost
::
bad_get
e
)
{
}
}
}
catch
(
boost
::
bad_get
e
)
{
}
}
}
}
}
}
}
}
bool
use_gpu
=
false
;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
use_gpu
=
nccl_ctxs_
!=
nullptr
;
#endif
// Insert broadcast operators principle:
InsertPostprocessOps
(
&
result
);
// 1. Broadcast optimized parameters in Reduce strategy;
// 2. No need broadcast optimized parameters in AllReduce strategy because of
// the optimization sub-graph would be run on every GPU;
// 3. Allways broadcast received parameters in Distribute Training.
if
((
use_gpu
&&
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kReduce
)
||
is_dist_train
)
{
if
(
strategy_
.
fuse_broadcast_op_
)
{
CreateFusedBroadcastOp
(
&
result
,
bcast_var_name_set
);
}
else
{
for
(
size_t
dev_id
=
0
;
dev_id
<
bcast_var_name_set
.
size
();
++
dev_id
)
{
auto
&
to_bcast_set
=
bcast_var_name_set
[
dev_id
];
for
(
auto
&
bcast_name
:
to_bcast_set
)
{
CreateBroadcastOp
(
&
result
,
bcast_name
,
dev_id
);
}
}
}
}
/*
/*
Dependency graph has been constructed. However, there are still data
Dependency graph has been constructed. However, there are still data
hazards need to be handled.
hazards need to be handled.
*/
*/
PolishGraphToSupportDataHazards
(
&
result
);
PolishGraphToSupportDataHazards
(
&
result
);
/*
/*
...
@@ -337,67 +230,54 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -337,67 +230,54 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
return
graph
;
return
graph
;
}
}
std
::
vector
<
ir
::
Node
*>
MultiDevSSAGraphBuilder
::
SortForReduceMode
(
void
MultiDevSSAGraphBuilderBase
::
InsertScaleLossGradOp
(
const
std
::
vector
<
ir
::
Node
*>
&
topo_ops
)
const
{
ir
::
Graph
*
result
,
const
ir
::
Node
*
node
)
const
{
std
::
unordered_map
<
std
::
string
,
int
>
sharded_var_device
;
// user can customize loss@grad if not use_default_grad_scale_
std
::
vector
<
ir
::
Node
*>
sorted_ops
;
size_t
loss_scale
=
0
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
delayed_op
;
switch
(
this
->
strategy_
.
gradient_scale_
)
{
sorted_ops
.
reserve
(
topo_ops
.
size
());
case
BuildStrategy
::
GradientScaleStrategy
::
kOne
:
loss_scale
=
1
;
auto
insert_delayed_op
=
[
&
](
const
std
::
string
&
var_name
,
int
dev_id
)
{
break
;
sharded_var_device
.
emplace
(
var_name
,
dev_id
);
case
BuildStrategy
::
GradientScaleStrategy
::
kCoeffNumDevice
:
if
(
delayed_op
.
count
(
var_name
))
{
loss_scale
=
Get
<
size_t
>
(
kNRanks
);
auto
&
ops
=
delayed_op
.
at
(
var_name
);
break
;
sorted_ops
.
insert
(
sorted_ops
.
end
(),
ops
.
begin
(),
ops
.
end
());
case
BuildStrategy
::
GradientScaleStrategy
::
kCustomized
:
delayed_op
.
at
(
var_name
).
clear
();
loss_scale
=
0
;
}
break
;
};
default:
LOG
(
FATAL
)
<<
"Unknown gradient scale strategy."
;
break
;
}
if
(
loss_scale
)
{
// TODO(paddle-dev): Why is there no input for this op_handle?
auto
loss_grad_name
=
node
->
Op
()
->
OutputArgumentNames
()[
0
];
auto
out_dtype
=
this
->
all_vars_
.
at
(
loss_grad_name
)
->
GetDataType
();
this
->
CreateScaleLossGradOp
(
result
,
loss_grad_name
,
node
->
outputs
[
0
],
loss_scale
,
out_dtype
);
}
}
for
(
ir
::
Node
*
node
:
topo_ops
)
{
std
::
vector
<
ir
::
Node
*>
MultiDevSSAGraphBuilderBase
::
SortOperations
(
int
op_dev_id
=
GetOpDeviceID
(
node
,
sharded_var_device
,
&
delayed_op
);
const
ir
::
Graph
&
graph
)
const
{
if
(
op_dev_id
>
-
1
)
{
return
ir
::
TopologySortOperations
(
graph
);
// This op only runs on one specific device.
}
sorted_ops
.
emplace_back
(
node
);
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
insert_delayed_op
(
n
->
Name
(),
op_dev_id
);
}
}
else
if
(
op_dev_id
==
-
1
)
{
// This op runs on all devices, and its output may have parameter's
// gradients.
sorted_ops
.
emplace_back
(
node
);
bool
is_bk_op
=
static_cast
<
bool
>
(
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
&
static_cast
<
int
>
(
OpRole
::
kBackward
));
if
(
!
is_bk_op
)
continue
;
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
std
::
vector
<
std
::
string
>
backward_vars
;
try
{
backward_vars
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetNullableAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
}
catch
(
boost
::
bad_get
e
)
{
}
PADDLE_ENFORCE_EQ
(
backward_vars
.
size
()
%
2
,
0
);
for
(
size_t
i
=
0
;
i
<
backward_vars
.
size
();
i
+=
2
)
{
bool
MultiDevSSAGraphBuilderBase
::
UseGPU
()
const
{
auto
&
g_name
=
backward_vars
[
i
+
1
];
bool
use_gpu
=
false
;
size_t
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
insert_delayed_op
(
g_name
,
static_cast
<
int
>
(
cur_device_id
));
use_gpu
=
nccl_ctxs_
!=
nullptr
;
}
#endif
}
else
if
(
op_dev_id
==
-
2
)
{
return
use_gpu
;
// The Op on which the Op depends has not yet been generated.
}
}
}
PADDLE_ENFORCE_EQ
(
sorted_ops
.
size
(),
topo_ops
.
size
());
bool
MultiDevSSAGraphBuilderBase
::
NeedCollectiveOps
()
const
{
return
sorted_ops
;
return
Get
<
size_t
>
(
kNRanks
)
>
1
;
}
}
void
MultiDevSSAGraphBuilder
::
CreateOpHandleIOs
(
ir
::
Graph
*
result
,
void
MultiDevSSAGraphBuilder
Base
::
CreateOpHandleIOs
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
ir
::
Node
*
node
,
size_t
place_id
)
const
{
size_t
place_id
)
const
{
auto
p
=
places_
[
place_id
];
auto
p
=
places_
[
place_id
];
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
();
op_handle
->
SetDeviceContext
(
p
,
op_handle
->
SetDeviceContext
(
p
,
...
@@ -420,28 +300,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
...
@@ -420,28 +300,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
}
}
}
}
size_t
MultiDevSSAGraphBuilder
::
GetAppropriateDeviceID
(
void
MultiDevSSAGraphBuilderBase
::
SetCommunicationContext
(
const
std
::
vector
<
std
::
string
>
&
var_names
)
const
{
int64_t
numel_sum
=
0
;
for
(
auto
var_name
:
var_names
)
{
if
(
all_vars_
.
find
(
var_name
)
==
all_vars_
.
end
())
continue
;
auto
var_desc
=
all_vars_
.
at
(
var_name
);
PADDLE_ENFORCE_NOT_NULL
(
var_desc
);
auto
dim
=
framework
::
make_ddim
(
var_desc
->
GetShape
());
int64_t
numel
=
framework
::
product
(
dim
);
PADDLE_ENFORCE_GT
(
numel
,
0
);
numel_sum
+=
numel
;
}
auto
smallest
=
std
::
min_element
(
std
::
begin
(
balance_vars_
),
std
::
end
(
balance_vars_
));
size_t
dev_id
=
static_cast
<
size_t
>
(
std
::
distance
(
std
::
begin
(
balance_vars_
),
smallest
));
balance_vars_
[
dev_id
]
+=
numel_sum
;
return
dev_id
;
}
void
MultiDevSSAGraphBuilder
::
SetCommunicationContext
(
OpHandleBase
*
op_handle
,
const
platform
::
Place
&
p
)
const
{
OpHandleBase
*
op_handle
,
const
platform
::
Place
&
p
)
const
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if
(
nccl_ctxs_
==
nullptr
)
{
if
(
nccl_ctxs_
==
nullptr
)
{
...
@@ -454,9 +313,9 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
...
@@ -454,9 +313,9 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
#endif
#endif
}
}
void
MultiDevSSAGraphBuilder
::
CreateBroadcastOp
(
ir
::
Graph
*
result
,
void
MultiDevSSAGraphBuilder
Base
::
CreateBroadcastOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
p_name
,
size_t
src_dev_id
)
const
{
size_t
src_dev_id
)
const
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto
*
op_handle
=
new
BroadcastOpHandle
(
auto
*
op_handle
=
new
BroadcastOpHandle
(
result
->
CreateEmptyNode
(
"broadcast"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"broadcast"
,
ir
::
Node
::
Type
::
kOperation
),
...
@@ -484,7 +343,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
...
@@ -484,7 +343,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
}
}
}
}
void
MultiDevSSAGraphBuilder
::
CreateFusedBroadcastOp
(
void
MultiDevSSAGraphBuilder
Base
::
CreateFusedBroadcastOp
(
ir
::
Graph
*
result
,
ir
::
Graph
*
result
,
const
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
&
bcast_varnames
)
const
{
const
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
&
bcast_varnames
)
const
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
...
@@ -522,17 +381,17 @@ void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
...
@@ -522,17 +381,17 @@ void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
}
}
}
}
void
MultiDevSSAGraphBuilder
::
CreateComputationalOp
(
ir
::
Graph
*
result
,
void
MultiDevSSAGraphBuilder
Base
::
CreateComputationalOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
ir
::
Node
*
node
,
int
dev_id
)
const
{
int
dev_id
)
const
{
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
local_scopes_
[
dev_id
],
places_
[
dev_id
],
dev_id
));
local_scopes_
[
dev_id
],
places_
[
dev_id
],
dev_id
));
CreateOpHandleIOs
(
result
,
node
,
dev_id
);
CreateOpHandleIOs
(
result
,
node
,
dev_id
);
}
}
void
MultiDevSSAGraphBuilder
::
InsertAllReduceOp
(
ir
::
Graph
*
result
,
void
MultiDevSSAGraphBuilder
Base
::
CreateAllReduceOp
(
const
std
::
string
&
og
)
const
{
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
...
@@ -560,102 +419,15 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
...
@@ -560,102 +419,15 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
}
}
}
}
void
MultiDevSSAGraphBuilder
::
InsertDataBalanceOp
(
void
MultiDevSSAGraphBuilderBase
::
CreateScaleLossGradOp
(
ir
::
Graph
*
result
,
const
std
::
vector
<
std
::
string
>
&
datas
)
const
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
for
(
const
std
::
string
&
d_name
:
datas
)
{
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
d_name
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
op_handle
->
AddInput
(
vars
.
back
());
auto
var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
d_name
,
ir
::
Node
::
Type
::
kVariable
),
vars
.
size
(),
i
,
d_name
,
p
);
vars
.
emplace_back
(
var
);
op_handle
->
AddOutput
(
var
);
}
}
}
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
ir
::
Node
*
node
,
const
std
::
unordered_map
<
std
::
string
,
int
>
&
sharded_var_device
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
*
delay_ops
)
const
{
if
(
strategy_
.
reduce_
!=
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
return
-
1
;
}
if
(
!
OpHaveRole
(
*
node
,
framework
::
OpRole
::
kOptimize
))
{
return
-
1
;
}
auto
param_grad
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
PADDLE_ENFORCE_EQ
(
param_grad
.
size
(),
2U
);
int
dev_id
=
GetVarDeviceID
(
param_grad
[
1
],
sharded_var_device
);
if
(
dev_id
==
-
1
)
{
(
*
delay_ops
)[
param_grad
[
1
]].
push_back
(
node
);
return
-
2
;
}
return
dev_id
;
}
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
ir
::
Node
*
node
,
const
std
::
unordered_map
<
std
::
string
,
int
>
&
sharded_var_device
)
const
{
if
(
strategy_
.
reduce_
!=
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
return
-
1
;
}
if
(
!
OpHaveRole
(
*
node
,
framework
::
OpRole
::
kOptimize
))
{
return
-
1
;
}
auto
param_grad
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
PADDLE_ENFORCE_EQ
(
param_grad
.
size
(),
2U
);
int
dev_id
=
GetVarDeviceID
(
param_grad
[
1
],
sharded_var_device
);
PADDLE_ENFORCE_NE
(
dev_id
,
-
1
,
"dev_id should not be -1.[%s, %s, %s]"
,
node
->
Op
()
->
Type
(),
param_grad
[
0
],
param_grad
[
1
]);
return
dev_id
;
}
int
MultiDevSSAGraphBuilder
::
GetVarDeviceID
(
const
std
::
string
&
varname
,
const
std
::
unordered_map
<
std
::
string
,
int
>
&
sharded_var_device
)
const
{
auto
got
=
sharded_var_device
.
find
(
varname
);
if
(
got
==
sharded_var_device
.
end
())
{
auto
pos
=
varname
.
find
(
framework
::
kNewGradSuffix
);
if
(
pos
!=
std
::
string
::
npos
)
{
got
=
sharded_var_device
.
find
(
varname
.
substr
(
0
,
pos
));
}
}
return
got
==
sharded_var_device
.
end
()
?
-
1
:
got
->
second
;
}
void
MultiDevSSAGraphBuilder
::
CreateScaleLossGradOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
loss_grad_name
,
ir
::
Graph
*
result
,
const
std
::
string
&
loss_grad_name
,
ir
::
Node
*
out_var_node
,
proto
::
VarType
::
Type
dtype
)
const
{
ir
::
Node
*
out_var_node
,
size_t
loss_scale
,
size_t
nranks
=
Get
<
size_t
>
(
"nranks"
);
proto
::
VarType
::
Type
dtype
)
const
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
// Insert ScaleCost OpHandle
auto
*
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
places_
[
i
]);
auto
*
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
places_
[
i
]);
auto
*
op_handle
=
new
ScaleLossGradOpHandle
(
auto
*
op_handle
=
new
ScaleLossGradOpHandle
(
result
->
CreateEmptyNode
(
"scale_loss_grad"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"scale_loss_grad"
,
ir
::
Node
::
Type
::
kOperation
),
nranks
,
local_scopes_
[
i
],
places_
[
i
],
dev_ctx
,
dtype
);
loss_scale
,
local_scopes_
[
i
],
places_
[
i
],
dev_ctx
,
dtype
);
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// FIXME: Currently ScaleLossGradOp only use device_count as scale
...
@@ -669,9 +441,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
...
@@ -669,9 +441,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
}
}
}
}
void
MultiDevSSAGraphBuilder
::
CreateComputationalOps
(
ir
::
Graph
*
result
,
void
MultiDevSSAGraphBuilderBase
::
CreateComputationalOps
(
ir
::
Node
*
node
,
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
size_t
num_places
)
const
{
size_t
num_places
)
const
{
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
auto
p
=
places_
[
scope_idx
];
auto
p
=
places_
[
scope_idx
];
auto
s
=
local_scopes_
[
scope_idx
];
auto
s
=
local_scopes_
[
scope_idx
];
...
@@ -681,9 +452,9 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
...
@@ -681,9 +452,9 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
}
}
}
}
VarHandle
*
MultiDevSSAGraphBuilder
::
CreateReduceOp
(
ir
::
Graph
*
result
,
VarHandle
*
MultiDevSSAGraphBuilder
Base
::
CreateReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
,
const
std
::
string
&
og
,
int
dst_dev_id
)
const
{
int
dst_dev_id
)
const
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ReduceOpHandle
(
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
...
@@ -712,51 +483,273 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
...
@@ -712,51 +483,273 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
return
var
;
return
var
;
}
}
int
MultiDevSSAGraphBuilder
::
CreateDistTrainOp
(
bool
MultiDevSSAGraphBuilderBase
::
IsScaleLossOp
(
ir
::
Node
*
node
)
const
{
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
return
boost
::
get
<
int
>
(
std
::
unordered_map
<
std
::
string
,
int
>
*
sharded_var_device
)
const
{
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
int
op_dev_id
=
-
1
;
(
static_cast
<
int
>
(
OpRole
::
kBackward
)
|
std
::
vector
<
std
::
string
>
input_var_names
;
static_cast
<
int
>
(
OpRole
::
kLoss
))
&&
std
::
vector
<
std
::
string
>
output_var_names
;
!
loss_var_name_
.
empty
();
// If loss_var is empty. This is test mode
for
(
ir
::
Node
*
input
:
node
->
inputs
)
{
}
input_var_names
.
push_back
(
input
->
Name
());
bool
MultiDevSSAGraphBuilderBase
::
IsSparseGradient
(
const
std
::
string
&
og
)
const
{
PADDLE_ENFORCE
(
all_vars_
.
count
(
og
)
!=
0
);
if
(
all_vars_
.
at
(
og
)
->
GetType
()
==
proto
::
VarType
::
SELECTED_ROWS
)
{
return
true
;
}
}
for
(
ir
::
Node
*
output
:
node
->
outputs
)
{
return
false
;
output_var_names
.
push_back
(
output
->
Name
());
}
void
AllReduceSSAGraphBuilder
::
InsertCollectiveOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
g_name
)
const
{
if
(
IsSparseGradient
(
g_name
))
{
CreateReduceOp
(
result
,
g_name
,
0
);
CreateBroadcastOp
(
result
,
g_name
,
0
);
}
else
{
CreateAllReduceOp
(
result
,
g_name
);
}
}
}
if
(
node
->
Op
()
->
Type
()
==
"split_byref"
||
int
BalanceVarSSAGraphBuilder
::
GetVarDeviceID
(
node
->
Op
()
->
Type
()
==
"split_selected_rows"
||
const
std
::
string
&
varname
)
const
{
node
->
Op
()
->
Type
()
==
"split_ids"
)
{
auto
got
=
sharded_var_device_
.
find
(
varname
);
// TODO(paddle-dev): getting the first var is not safe.
if
(
got
==
sharded_var_device_
.
end
())
{
op_dev_id
=
GetVarDeviceID
(
input_var_names
[
0
],
*
sharded_var_device
);
auto
pos
=
varname
.
find
(
framework
::
kNewGradSuffix
);
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
if
(
pos
!=
std
::
string
::
npos
)
{
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
got
=
sharded_var_device_
.
find
(
varname
.
substr
(
0
,
pos
));
for
(
auto
&
varname
:
input_var_names
)
{
}
sharded_var_device
->
emplace
(
varname
,
op_dev_id
);
}
return
got
==
sharded_var_device_
.
end
()
?
-
1
:
got
->
second
;
}
int
BalanceVarSSAGraphBuilder
::
GetOpDeviceID
(
ir
::
Node
*
node
)
const
{
if
(
strategy_
.
reduce_
!=
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
return
-
1
;
}
if
(
!
OpHaveRole
(
*
node
,
framework
::
OpRole
::
kOptimize
))
{
return
-
1
;
}
auto
param_grad
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
PADDLE_ENFORCE_EQ
(
param_grad
.
size
(),
2U
);
int
dev_id
=
GetVarDeviceID
(
param_grad
[
1
]);
PADDLE_ENFORCE_NE
(
dev_id
,
-
1
,
"dev_id should not be -1.[%s, %s, %s]"
,
node
->
Op
()
->
Type
(),
param_grad
[
0
],
param_grad
[
1
]);
return
dev_id
;
}
size_t
BalanceVarSSAGraphBuilder
::
GetAppropriateDeviceID
(
const
std
::
vector
<
std
::
string
>
&
var_names
)
const
{
int64_t
numel_sum
=
0
;
for
(
auto
var_name
:
var_names
)
{
if
(
all_vars_
.
find
(
var_name
)
==
all_vars_
.
end
())
continue
;
auto
var_desc
=
all_vars_
.
at
(
var_name
);
PADDLE_ENFORCE_NOT_NULL
(
var_desc
);
auto
dim
=
framework
::
make_ddim
(
var_desc
->
GetShape
());
int64_t
numel
=
framework
::
product
(
dim
);
PADDLE_ENFORCE_GT
(
numel
,
0
);
numel_sum
+=
numel
;
}
auto
smallest
=
std
::
min_element
(
std
::
begin
(
balance_vars_
),
std
::
end
(
balance_vars_
));
size_t
dev_id
=
static_cast
<
size_t
>
(
std
::
distance
(
std
::
begin
(
balance_vars_
),
smallest
));
balance_vars_
[
dev_id
]
+=
numel_sum
;
return
dev_id
;
}
void
BalanceVarSSAGraphBuilder
::
ResetState
()
const
{
balance_vars_
.
clear
();
sharded_var_device_
.
clear
();
balance_vars_
.
resize
(
places_
.
size
(),
0
);
}
void
ReduceSSAGraphBuilder
::
Init
()
const
{
MultiDevSSAGraphBuilderBase
::
Init
();
ResetState
();
}
void
ReduceSSAGraphBuilder
::
ResetState
()
const
{
BalanceVarSSAGraphBuilder
::
ResetState
();
bcast_var_name_set_
.
clear
();
bcast_var_name_set_
.
resize
(
places_
.
size
());
}
void
ReduceSSAGraphBuilder
::
InsertCollectiveOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
g_name
)
const
{
size_t
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
CreateReduceOp
(
result
,
g_name
,
cur_device_id
);
sharded_var_device_
.
emplace
(
g_name
,
cur_device_id
);
bcast_var_name_set_
[
cur_device_id
].
emplace
(
p_name
);
}
bool
ReduceSSAGraphBuilder
::
DealWithSpecialOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
{
int
op_dev_id
=
BalanceVarSSAGraphBuilder
::
GetOpDeviceID
(
node
);
if
(
op_dev_id
!=
-
1
)
{
// This op only runs on one specific device.
CreateComputationalOp
(
result
,
node
,
op_dev_id
);
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
sharded_var_device_
.
emplace
(
n
->
Name
(),
op_dev_id
);
}
return
true
;
}
return
false
;
}
void
ReduceSSAGraphBuilder
::
InsertPostprocessOps
(
ir
::
Graph
*
result
)
const
{
if
(
UseGPU
())
{
if
(
strategy_
.
fuse_broadcast_op_
)
{
CreateFusedBroadcastOp
(
result
,
bcast_var_name_set_
);
}
else
{
for
(
size_t
dev_id
=
0
;
dev_id
<
bcast_var_name_set_
.
size
();
++
dev_id
)
{
auto
&
to_bcast_set
=
bcast_var_name_set_
[
dev_id
];
for
(
auto
&
bcast_name
:
to_bcast_set
)
{
CreateBroadcastOp
(
result
,
bcast_name
,
dev_id
);
}
}
}
}
}
for
(
auto
&
varname
:
output_var_names
)
{
}
sharded_var_device
->
emplace
(
varname
,
op_dev_id
);
}
int
ReduceSSAGraphBuilder
::
GetOpDeviceID
(
ir
::
Node
*
node
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
*
delay_ops
)
const
{
if
(
!
OpHaveRole
(
*
node
,
framework
::
OpRole
::
kOptimize
))
{
return
-
1
;
}
auto
param_grad
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
PADDLE_ENFORCE_EQ
(
param_grad
.
size
(),
2U
);
int
dev_id
=
GetVarDeviceID
(
param_grad
[
1
]);
if
(
dev_id
==
-
1
)
{
(
*
delay_ops
)[
param_grad
[
1
]].
push_back
(
node
);
return
-
2
;
}
return
dev_id
;
}
std
::
vector
<
ir
::
Node
*>
ReduceSSAGraphBuilder
::
SortOperations
(
const
ir
::
Graph
&
graph
)
const
{
std
::
vector
<
ir
::
Node
*>
sorted_ops
=
ir
::
TopologySortOperations
(
graph
);
return
SortForReduceMode
(
sorted_ops
);
}
std
::
vector
<
ir
::
Node
*>
ReduceSSAGraphBuilder
::
SortForReduceMode
(
const
std
::
vector
<
ir
::
Node
*>
&
topo_ops
)
const
{
std
::
vector
<
ir
::
Node
*>
sorted_ops
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
delayed_op
;
sorted_ops
.
reserve
(
topo_ops
.
size
());
ResetState
();
auto
insert_delayed_op
=
[
&
](
const
std
::
string
&
var_name
,
int
dev_id
)
{
sharded_var_device_
.
emplace
(
var_name
,
dev_id
);
if
(
delayed_op
.
count
(
var_name
))
{
auto
&
ops
=
delayed_op
.
at
(
var_name
);
sorted_ops
.
insert
(
sorted_ops
.
end
(),
ops
.
begin
(),
ops
.
end
());
delayed_op
.
at
(
var_name
).
clear
();
}
}
}
else
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
};
op_dev_id
=
GetVarDeviceID
(
input_var_names
[
0
],
*
sharded_var_device
);
for
(
auto
&
varname
:
output_var_names
)
{
for
(
ir
::
Node
*
node
:
topo_ops
)
{
sharded_var_device
->
emplace
(
varname
,
op_dev_id
);
int
op_dev_id
=
GetOpDeviceID
(
node
,
&
delayed_op
);
if
(
op_dev_id
>
-
1
)
{
// This op only runs on one specific device.
sorted_ops
.
emplace_back
(
node
);
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
insert_delayed_op
(
n
->
Name
(),
op_dev_id
);
}
}
else
if
(
op_dev_id
==
-
1
)
{
// This op runs on all devices, and its output may have parameter's
// gradients.
sorted_ops
.
emplace_back
(
node
);
bool
is_bk_op
=
static_cast
<
bool
>
(
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
&
static_cast
<
int
>
(
OpRole
::
kBackward
));
if
(
!
is_bk_op
)
continue
;
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
std
::
vector
<
std
::
string
>
backward_vars
;
try
{
backward_vars
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetNullableAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
}
catch
(
boost
::
bad_get
e
)
{
}
PADDLE_ENFORCE_EQ
(
backward_vars
.
size
()
%
2
,
0
);
for
(
size_t
i
=
0
;
i
<
backward_vars
.
size
();
i
+=
2
)
{
auto
&
g_name
=
backward_vars
[
i
+
1
];
size_t
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
insert_delayed_op
(
g_name
,
static_cast
<
int
>
(
cur_device_id
));
}
}
else
if
(
op_dev_id
==
-
2
)
{
// The Op on which the Op depends has not yet been generated.
}
}
}
else
{
LOG
(
ERROR
)
<<
"got unexpected dist op: "
<<
node
->
Op
()
->
Type
();
PADDLE_THROW
(
"the distribute training related op should be in [split_byref, "
"concat]."
);
}
}
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
PADDLE_ENFORCE_EQ
(
sorted_ops
.
size
(),
topo_ops
.
size
());
"can not find right place for distributed op: %s"
,
node
->
Op
()
->
Type
());
CreateComputationalOp
(
result
,
node
,
op_dev_id
);
ResetState
();
return
op_dev_id
;
return
sorted_ops
;
}
void
DistSSAGraphBuilder
::
Init
()
const
{
MultiDevSSAGraphBuilderBase
::
Init
();
ResetState
();
}
void
DistSSAGraphBuilder
::
ResetState
()
const
{
BalanceVarSSAGraphBuilder
::
ResetState
();
bcast_var_name_set_
.
clear
();
bcast_var_name_set_
.
resize
(
places_
.
size
());
}
bool
DistSSAGraphBuilder
::
DealWithSpecialOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
{
bool
insert_op
=
false
;
if
(
OpHaveRole
(
*
node
,
OpRole
::
kRPC
))
{
int
op_dev_id
=
CreateRPCOp
(
result
,
node
);
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
"Can not schedule the RPC operator to the right place."
);
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
auto
recv_vars_attr
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetNullableAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
PADDLE_ENFORCE
(
recv_vars_attr
.
size
()
==
2UL
);
// [parameter, gradient]
if
(
recv_vars_attr
[
0
].
find
(
".block"
)
==
std
::
string
::
npos
)
{
bcast_var_name_set_
[
op_dev_id
].
emplace
(
recv_vars_attr
[
0
]);
}
}
insert_op
=
true
;
need_broadcast_var_
=
true
;
}
else
if
(
OpHaveRole
(
*
node
,
OpRole
::
kDist
))
{
int
op_dev_id
=
CreateDistTrainOp
(
result
,
node
);
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
auto
origin_param_name
=
node
->
Op
()
->
OutputArgumentNames
()[
0
];
bcast_var_name_set_
[
op_dev_id
].
emplace
(
origin_param_name
);
}
insert_op
=
true
;
}
else
{
int
op_dev_id
=
GetOpDeviceID
(
node
);
if
(
op_dev_id
!=
-
1
)
{
// This op only runs on one specific device.
CreateComputationalOp
(
result
,
node
,
op_dev_id
);
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
sharded_var_device_
.
emplace
(
n
->
Name
(),
op_dev_id
);
}
insert_op
=
true
;
}
}
return
insert_op
;
}
}
void
SetOpInputsAllPlaces
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
int
num_places
)
{
void
SetOpInputsAllPlaces
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
int
num_places
)
{
...
@@ -775,13 +768,11 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
...
@@ -775,13 +768,11 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
}
}
// Create RPC related op handles that connects its in ops and out ops.
// Create RPC related op handles that connects its in ops and out ops.
int
MultiDevSSAGraphBuilder
::
CreateRPCOp
(
int
DistSSAGraphBuilder
::
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
{
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
std
::
unordered_map
<
std
::
string
,
int
>
*
sharded_var_device
)
const
{
int
op_dev_id
=
-
1
;
int
op_dev_id
=
-
1
;
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
// TODO(paddle-dev): getting the first var is not safe.
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id
=
GetVarDeviceID
(
node
->
inputs
[
0
]
->
Name
()
,
*
sharded_var_device
);
op_dev_id
=
GetVarDeviceID
(
node
->
inputs
[
0
]
->
Name
());
PADDLE_ENFORCE
(
!
ir
::
IsControlDepVar
(
*
node
->
inputs
[
0
]),
PADDLE_ENFORCE
(
!
ir
::
IsControlDepVar
(
*
node
->
inputs
[
0
]),
"This hack no longer holds, please fix."
);
"This hack no longer holds, please fix."
);
// the variable name which contains .block means it was splited by
// the variable name which contains .block means it was splited by
...
@@ -799,9 +790,9 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
...
@@ -799,9 +790,9 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
VLOG
(
10
)
<<
"send grad "
<<
input_var_names
[
0
]
<<
" origin "
VLOG
(
10
)
<<
"send grad "
<<
input_var_names
[
0
]
<<
" origin "
<<
send_param_grad
[
1
]
<<
" place: "
<<
op_dev_id
;
<<
send_param_grad
[
1
]
<<
" place: "
<<
op_dev_id
;
for
(
auto
&
varname
:
input_var_names
)
{
for
(
auto
&
varname
:
input_var_names
)
{
sharded_var_device
->
emplace
(
varname
,
op_dev_id
);
sharded_var_device
_
.
emplace
(
varname
,
op_dev_id
);
}
}
sharded_var_device
->
emplace
(
send_param_grad
[
1
],
op_dev_id
);
sharded_var_device
_
.
emplace
(
send_param_grad
[
1
],
op_dev_id
);
}
}
}
else
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
}
else
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
std
::
vector
<
std
::
string
>
output_var_names
;
std
::
vector
<
std
::
string
>
output_var_names
;
...
@@ -811,7 +802,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
...
@@ -811,7 +802,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
auto
recv_param_grad
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
auto
recv_param_grad
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
if
(
recv_param_grad
.
size
()
==
2U
)
{
if
(
recv_param_grad
.
size
()
==
2U
)
{
op_dev_id
=
GetVarDeviceID
(
recv_param_grad
[
1
]
,
*
sharded_var_device
);
op_dev_id
=
GetVarDeviceID
(
recv_param_grad
[
1
]);
VLOG
(
10
)
<<
"recv param "
<<
recv_param_grad
[
0
]
VLOG
(
10
)
<<
"recv param "
<<
recv_param_grad
[
0
]
<<
" get grad place: "
<<
recv_param_grad
[
1
]
<<
" get grad place: "
<<
recv_param_grad
[
1
]
<<
" place: "
<<
op_dev_id
;
<<
" place: "
<<
op_dev_id
;
...
@@ -819,7 +810,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
...
@@ -819,7 +810,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
op_dev_id
=
GetAppropriateDeviceID
(
output_var_names
);
op_dev_id
=
GetAppropriateDeviceID
(
output_var_names
);
}
}
for
(
auto
&
varname
:
output_var_names
)
{
for
(
auto
&
varname
:
output_var_names
)
{
sharded_var_device
->
emplace
(
varname
,
op_dev_id
);
sharded_var_device
_
.
emplace
(
varname
,
op_dev_id
);
}
}
}
else
{
}
else
{
// send_barrier, fetch_barrier will run on place 0;
// send_barrier, fetch_barrier will run on place 0;
...
@@ -846,7 +837,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
...
@@ -846,7 +837,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
for
(
ir
::
Node
*
output
:
node
->
outputs
)
{
for
(
ir
::
Node
*
output
:
node
->
outputs
)
{
int
outvar_dev_id
=
op_dev_id
;
int
outvar_dev_id
=
op_dev_id
;
if
(
node
->
Op
()
->
Type
()
==
"fetch_barrier"
)
{
if
(
node
->
Op
()
->
Type
()
==
"fetch_barrier"
)
{
outvar_dev_id
=
GetVarDeviceID
(
output
->
Name
()
,
*
sharded_var_device
);
outvar_dev_id
=
GetVarDeviceID
(
output
->
Name
());
PADDLE_ENFORCE_NE
(
outvar_dev_id
,
-
1
,
"output name %s"
,
output
->
Name
());
PADDLE_ENFORCE_NE
(
outvar_dev_id
,
-
1
,
"output name %s"
,
output
->
Name
());
}
}
p
=
places_
[
outvar_dev_id
];
p
=
places_
[
outvar_dev_id
];
...
@@ -863,29 +854,124 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
...
@@ -863,29 +854,124 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
return
op_dev_id
;
return
op_dev_id
;
}
}
bool
MultiDevSSAGraphBuilder
::
IsSparseGradient
(
const
std
::
string
&
og
)
const
{
int
DistSSAGraphBuilder
::
CreateDistTrainOp
(
ir
::
Graph
*
result
,
PADDLE_ENFORCE
(
all_vars_
.
count
(
og
)
!=
0
);
ir
::
Node
*
node
)
const
{
if
(
all_vars_
.
at
(
og
)
->
GetType
()
==
proto
::
VarType
::
SELECTED_ROWS
)
{
int
op_dev_id
=
-
1
;
return
true
;
std
::
vector
<
std
::
string
>
input_var_names
;
std
::
vector
<
std
::
string
>
output_var_names
;
for
(
ir
::
Node
*
input
:
node
->
inputs
)
{
input_var_names
.
push_back
(
input
->
Name
());
}
}
return
false
;
for
(
ir
::
Node
*
output
:
node
->
outputs
)
{
output_var_names
.
push_back
(
output
->
Name
());
}
if
(
node
->
Op
()
->
Type
()
==
"split_byref"
||
node
->
Op
()
->
Type
()
==
"split_selected_rows"
||
node
->
Op
()
->
Type
()
==
"split_ids"
)
{
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id
=
GetVarDeviceID
(
input_var_names
[
0
]);
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
for
(
auto
&
varname
:
input_var_names
)
{
sharded_var_device_
.
emplace
(
varname
,
op_dev_id
);
}
}
for
(
auto
&
varname
:
output_var_names
)
{
sharded_var_device_
.
emplace
(
varname
,
op_dev_id
);
}
}
else
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
op_dev_id
=
GetVarDeviceID
(
input_var_names
[
0
]);
for
(
auto
&
varname
:
output_var_names
)
{
sharded_var_device_
.
emplace
(
varname
,
op_dev_id
);
}
}
else
{
LOG
(
ERROR
)
<<
"got unexpected dist op: "
<<
node
->
Op
()
->
Type
();
PADDLE_THROW
(
"the distribute training related op should be in [split_byref, "
"concat]."
);
}
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
"can not find right place for distributed op: %s"
,
node
->
Op
()
->
Type
());
CreateComputationalOp
(
result
,
node
,
op_dev_id
);
return
op_dev_id
;
}
}
bool
MultiDevSSAGraphBuilder
::
IsScaleLossOp
(
ir
::
Node
*
node
)
const
{
void
DistSSAGraphBuilder
::
InsertCollectiveOp
(
ir
::
Graph
*
result
,
return
boost
::
get
<
int
>
(
const
std
::
string
&
p_name
,
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
const
std
::
string
&
g_name
)
const
{
(
static_cast
<
int
>
(
OpRole
::
kBackward
)
|
size_t
cur_device_id
=
0
;
static_cast
<
int
>
(
OpRole
::
kLoss
))
&&
switch
(
strategy_
.
reduce_
)
{
!
loss_var_name_
.
empty
();
// If loss_var is empty. This is test mode
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
CreateReduceOp
(
result
,
g_name
,
cur_device_id
);
sharded_var_device_
.
emplace
(
g_name
,
cur_device_id
);
break
;
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
if
(
IsSparseGradient
(
g_name
))
{
CreateReduceOp
(
result
,
g_name
,
0
);
CreateBroadcastOp
(
result
,
g_name
,
0
);
}
else
{
CreateAllReduceOp
(
result
,
g_name
);
}
break
;
default:
LOG
(
FATAL
)
<<
"Unknown reduce strategy."
;
break
;
}
}
void
DistSSAGraphBuilder
::
InsertPostprocessOps
(
ir
::
Graph
*
result
)
const
{
if
(
need_broadcast_var_
||
(
UseGPU
()
&&
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kReduce
))
{
if
(
strategy_
.
fuse_broadcast_op_
)
{
CreateFusedBroadcastOp
(
result
,
bcast_var_name_set_
);
}
else
{
for
(
size_t
dev_id
=
0
;
dev_id
<
bcast_var_name_set_
.
size
();
++
dev_id
)
{
auto
&
to_bcast_set
=
bcast_var_name_set_
[
dev_id
];
for
(
auto
&
bcast_name
:
to_bcast_set
)
{
CreateBroadcastOp
(
result
,
bcast_name
,
dev_id
);
}
}
}
}
}
std
::
unordered_set
<
std
::
string
>
&
MultiDevSSAGraphBuilder
()
{
static
std
::
unordered_set
<
std
::
string
>
regs
;
return
regs
;
}
}
static
int
MultiDevSSAGraphBuilderRegister
(
const
std
::
string
&
builder_mode
)
{
MultiDevSSAGraphBuilder
().
insert
(
builder_mode
);
return
0
;
}
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
multi_devices_pass
,
#define REGISTER_MULTI_DEVICES_PASS(pass_name, pass_class) \
paddle
::
framework
::
details
::
MultiDevSSAGraphBuilder
)
STATIC_ASSERT_GLOBAL_NAMESPACE( \
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLossVarName
)
_reg_ssa_graph_builder_##pass_name, \
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
)
"REGISTER_MULTI_DEVICES_PASS must be called in global namespace."); \
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLocalScopes
)
int _reg_ssa_graph_builder_entry_##pass_name = \
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kStrategy
)
paddle::framework::details::MultiDevSSAGraphBuilderRegister(#pass_name); \
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kNRanks
);
REGISTER_PASS(pass_name, pass_class) \
.RequirePassAttr(paddle::framework::details::kLossVarName) \
.RequirePassAttr(paddle::framework::details::kPlaces) \
.RequirePassAttr(paddle::framework::details::kLocalScopes) \
.RequirePassAttr(paddle::framework::details::kStrategy) \
.RequirePassAttr(paddle::framework::details::kNRanks)
REGISTER_MULTI_DEVICES_PASS
(
reduce_mode_multi_devices_pass
,
paddle
::
framework
::
details
::
ReduceSSAGraphBuilder
);
REGISTER_MULTI_DEVICES_PASS
(
allreduce_mode_multi_devices_pass
,
paddle
::
framework
::
details
::
AllReduceSSAGraphBuilder
);
REGISTER_MULTI_DEVICES_PASS
(
dist_multi_devices_pass
,
paddle
::
framework
::
details
::
DistSSAGraphBuilder
);
paddle/fluid/framework/details/multi_devices_graph_pass.h
浏览文件 @
68a07328
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// limitations under the License.
// limitations under the License.
#pragma once
#pragma once
#include <string>
#include <string>
#include <utility>
#include <utility>
#include <vector>
#include <vector>
...
@@ -30,78 +31,70 @@ namespace framework {
...
@@ -30,78 +31,70 @@ namespace framework {
class
Scope
;
class
Scope
;
namespace
details
{
namespace
details
{
class
MultiDevSSAGraphBuilder
:
public
ir
::
Pass
{
constexpr
char
kLossVarName
[]
=
"loss_var_name"
;
constexpr
char
kPlaces
[]
=
"places"
;
constexpr
char
kLocalScopes
[]
=
"local_scopes"
;
constexpr
char
kStrategy
[]
=
"strategy"
;
constexpr
char
kNRanks
[]
=
"nranks"
;
class
MultiDevSSAGraphBuilderBase
:
public
ir
::
Pass
{
protected:
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
private:
virtual
void
Init
()
const
;
void
CreateOpHandleIOs
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
size_t
device_id
)
const
;
void
Init
()
const
;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
virtual
std
::
vector
<
ir
::
Node
*>
SortOperations
(
const
ir
::
Graph
&
graph
)
const
;
mutable
platform
::
NCCLContextMap
*
nccl_ctxs_
;
#endif
int
GetVarDeviceID
(
virtual
void
InsertCollectiveOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
varname
,
const
std
::
string
&
g_name
)
const
=
0
;
const
std
::
unordered_map
<
std
::
string
,
int
>
&
sharded_var_device
)
const
;
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
virtual
bool
DealWithSpecialOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
=
0
;
virtual
void
InsertPostprocessOps
(
ir
::
Graph
*
result
)
const
=
0
;
int
CreateRPCOp
(
bool
UseGPU
()
const
;
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
std
::
unordered_map
<
std
::
string
,
int
>
*
sharded_var_device
)
const
;
bool
NeedCollectiveOps
()
const
;
int
CreateDistTrainOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
std
::
unordered_map
<
std
::
string
,
int
>
*
sharded_var_device
)
const
;
void
CreateComputationalOps
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
void
CreateComputationalOps
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
size_t
num_places
)
const
;
size_t
num_places
)
const
;
void
CreateScaleLossGradOp
(
ir
::
Graph
*
result
,
void
CreateScaleLossGradOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
loss_grad_name
,
const
std
::
string
&
loss_grad_name
,
ir
::
Node
*
out_var_node
,
ir
::
Node
*
out_var_node
,
size_t
loss_scale
,
proto
::
VarType
::
Type
dtype
)
const
;
proto
::
VarType
::
Type
dtype
)
const
;
VarHandle
*
CreateReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
,
VarHandle
*
CreateReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
,
int
dst_dev_id
)
const
;
int
dst_dev_id
)
const
;
void
CreateComputationalOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
void
CreateComputationalOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
int
dev_id
)
const
;
int
dev_id
)
const
;
int
GetOpDeviceID
(
bool
IsSparseGradient
(
const
std
::
string
&
og
)
const
;
ir
::
Node
*
node
,
const
std
::
unordered_map
<
std
::
string
,
int
>
&
sharded_var_device
)
const
;
void
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
;
void
InsertDataBalanceOp
(
ir
::
Graph
*
result
,
void
CreateAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
;
const
std
::
vector
<
std
::
string
>
&
datas
)
const
;
void
CreateBroadcastOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
void
CreateBroadcastOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
size_t
src_dev_id
)
const
;
size_t
src_dev_id
)
const
;
void
InsertScaleLossGradOp
(
ir
::
Graph
*
result
,
const
ir
::
Node
*
node
)
const
;
void
CreateFusedBroadcastOp
(
void
CreateFusedBroadcastOp
(
ir
::
Graph
*
result
,
ir
::
Graph
*
result
,
const
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
&
bcast_varnames
)
const
;
const
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
&
bcast_varnames
)
const
;
bool
IsSparseGradient
(
const
std
::
string
&
og
)
const
;
size_t
GetAppropriateDeviceID
(
const
std
::
vector
<
std
::
string
>
&
var_names
)
const
;
void
SetCommunicationContext
(
OpHandleBase
*
op_handle
,
void
SetCommunicationContext
(
OpHandleBase
*
op_handle
,
const
platform
::
Place
&
p
)
const
;
const
platform
::
Place
&
p
)
const
;
std
::
vector
<
ir
::
Node
*>
SortForReduceMode
(
void
CreateOpHandleIOs
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
const
std
::
vector
<
ir
::
Node
*>
&
)
const
;
size_t
device_id
)
const
;
int
GetOpDeviceID
(
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
ir
::
Node
*
node
,
mutable
platform
::
NCCLContextMap
*
nccl_ctxs_
;
const
std
::
unordered_map
<
std
::
string
,
int
>
&
shared_var_device
,
#endif
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
*
delay_ops
)
const
;
mutable
std
::
string
loss_var_name_
;
mutable
std
::
string
loss_var_name_
;
mutable
std
::
vector
<
platform
::
Place
>
places_
;
mutable
std
::
vector
<
platform
::
Place
>
places_
;
...
@@ -109,8 +102,83 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
...
@@ -109,8 +102,83 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
mutable
BuildStrategy
strategy_
;
mutable
BuildStrategy
strategy_
;
mutable
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars_
;
mutable
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars_
;
};
class
AllReduceSSAGraphBuilder
:
public
MultiDevSSAGraphBuilderBase
{
protected:
virtual
void
InsertCollectiveOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
g_name
)
const
;
virtual
bool
DealWithSpecialOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
{
return
false
;
}
virtual
void
InsertPostprocessOps
(
ir
::
Graph
*
result
)
const
{}
};
class
BalanceVarSSAGraphBuilder
:
public
MultiDevSSAGraphBuilderBase
{
protected:
int
GetVarDeviceID
(
const
std
::
string
&
varname
)
const
;
int
GetOpDeviceID
(
ir
::
Node
*
node
)
const
;
size_t
GetAppropriateDeviceID
(
const
std
::
vector
<
std
::
string
>
&
var_names
)
const
;
virtual
void
ResetState
()
const
;
mutable
std
::
unordered_map
<
std
::
string
,
int
>
sharded_var_device_
;
mutable
std
::
vector
<
int64_t
>
balance_vars_
;
mutable
std
::
vector
<
int64_t
>
balance_vars_
;
};
};
class
ReduceSSAGraphBuilder
:
public
BalanceVarSSAGraphBuilder
{
protected:
virtual
void
Init
()
const
;
virtual
void
InsertCollectiveOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
g_name
)
const
;
virtual
bool
DealWithSpecialOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
virtual
void
InsertPostprocessOps
(
ir
::
Graph
*
result
)
const
;
virtual
std
::
vector
<
ir
::
Node
*>
SortOperations
(
const
ir
::
Graph
&
graph
)
const
;
virtual
void
ResetState
()
const
;
int
GetOpDeviceID
(
ir
::
Node
*
node
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
*
delay_ops
)
const
;
std
::
vector
<
ir
::
Node
*>
SortForReduceMode
(
const
std
::
vector
<
ir
::
Node
*>
&
topo_ops
)
const
;
mutable
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
bcast_var_name_set_
;
};
class
DistSSAGraphBuilder
:
public
BalanceVarSSAGraphBuilder
{
protected:
virtual
void
Init
()
const
;
virtual
bool
DealWithSpecialOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
virtual
void
InsertPostprocessOps
(
ir
::
Graph
*
result
)
const
;
virtual
void
InsertCollectiveOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
g_name
)
const
;
virtual
void
ResetState
()
const
;
int
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
int
CreateDistTrainOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
mutable
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
bcast_var_name_set_
;
mutable
bool
need_broadcast_var_
{
false
};
};
std
::
unordered_set
<
std
::
string
>
&
MultiDevSSAGraphBuilder
();
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/naive_executor.cc
浏览文件 @
68a07328
...
@@ -40,14 +40,14 @@ void NaiveExecutor::Prepare(Scope *scope, const ProgramDesc &program_desc,
...
@@ -40,14 +40,14 @@ void NaiveExecutor::Prepare(Scope *scope, const ProgramDesc &program_desc,
void
NaiveExecutor
::
Run
()
{
void
NaiveExecutor
::
Run
()
{
#ifndef PADDLE_ON_INFERENCE
#ifndef PADDLE_ON_INFERENCE
LOG_FIRST_N
(
WARNING
,
1
5
)
<<
"The NaiveExecutor can not work properly if the "
LOG_FIRST_N
(
WARNING
,
5
)
<<
"The NaiveExecutor can not work properly if the "
"cmake flag ON_INFER is not set."
;
"cmake flag ON_INFER is not set."
;
LOG_FIRST_N
(
WARNING
,
1
5
)
<<
"Unlike the training phase, all the scopes and "
LOG_FIRST_N
(
WARNING
,
5
)
<<
"Unlike the training phase, all the scopes and "
"variables will be reused to save the allocation "
"variables will be reused to save the allocation "
"overhead."
;
"overhead."
;
LOG_FIRST_N
(
WARNING
,
1
5
)
<<
"Please re-compile the inference library by "
LOG_FIRST_N
(
WARNING
,
5
)
<<
"Please re-compile the inference library by "
"setting the cmake flag ON_INFER=ON if you are "
"setting the cmake flag ON_INFER=ON if you are "
"running Paddle Inference"
;
"running Paddle Inference"
;
#endif // PADDLE_ON_INFERENCE
#endif // PADDLE_ON_INFERENCE
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
op
:
ops_
)
{
VLOG
(
3
)
<<
std
::
this_thread
::
get_id
()
<<
" run "
<<
op
->
Type
()
VLOG
(
3
)
<<
std
::
this_thread
::
get_id
()
<<
" run "
<<
op
->
Type
()
...
...
paddle/fluid/framework/ngraph_operator.cc
浏览文件 @
68a07328
...
@@ -539,7 +539,7 @@ void NgraphEngine::Run(const Scope& scope, const platform::Place& place) const {
...
@@ -539,7 +539,7 @@ void NgraphEngine::Run(const Scope& scope, const platform::Place& place) const {
}
}
}
}
backend_
->
call
(
ngraph_function_
,
t_out
,
t_in
);
backend_
->
call
(
backend_
->
compile
(
ngraph_function_
)
,
t_out
,
t_in
);
}
// NgraphEngine::RunImpl
}
// NgraphEngine::RunImpl
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/inference/analysis/argument.h
浏览文件 @
68a07328
...
@@ -123,8 +123,6 @@ struct Argument {
...
@@ -123,8 +123,6 @@ struct Argument {
DECL_ARGUMENT_FIELD
(
use_gpu
,
UseGPU
,
bool
);
DECL_ARGUMENT_FIELD
(
use_gpu
,
UseGPU
,
bool
);
DECL_ARGUMENT_FIELD
(
gpu_device_id
,
GPUDeviceId
,
int
);
DECL_ARGUMENT_FIELD
(
gpu_device_id
,
GPUDeviceId
,
int
);
DECL_ARGUMENT_FIELD
(
use_tensorrt
,
UseTensorRT
,
bool
);
DECL_ARGUMENT_FIELD
(
use_tensorrt
,
UseTensorRT
,
bool
);
DECL_ARGUMENT_FIELD
(
tensorrt_node_teller
,
TensorRtNodeTeller
,
std
::
function
<
bool
(
const
framework
::
ir
::
Node
*
)
>
);
DECL_ARGUMENT_FIELD
(
tensorrt_max_batch_size
,
TensorRtMaxBatchSize
,
int
);
DECL_ARGUMENT_FIELD
(
tensorrt_max_batch_size
,
TensorRtMaxBatchSize
,
int
);
DECL_ARGUMENT_FIELD
(
tensorrt_workspace_size
,
TensorRtWorkspaceSize
,
int
);
DECL_ARGUMENT_FIELD
(
tensorrt_workspace_size
,
TensorRtWorkspaceSize
,
int
);
DECL_ARGUMENT_FIELD
(
tensorrt_min_subgraph_size
,
TensorRtMinSubgraphSize
,
int
);
DECL_ARGUMENT_FIELD
(
tensorrt_min_subgraph_size
,
TensorRtMinSubgraphSize
,
int
);
...
...
paddle/fluid/inference/analysis/ir_pass_manager.cc
浏览文件 @
68a07328
...
@@ -49,13 +49,6 @@ void IRPassManager::CreatePasses(Argument *argument,
...
@@ -49,13 +49,6 @@ void IRPassManager::CreatePasses(Argument *argument,
for
(
const
std
::
string
&
pass_name
:
passes
)
{
for
(
const
std
::
string
&
pass_name
:
passes
)
{
auto
pass
=
framework
::
ir
::
PassRegistry
::
Instance
().
Get
(
pass_name
);
auto
pass
=
framework
::
ir
::
PassRegistry
::
Instance
().
Get
(
pass_name
);
// Set some pass attributes.
if
(
pass_name
==
"ir_analysis_pass"
)
{
pass
->
Set
(
"tensorrt_node_teller"
,
new
SubgraphDetector
::
NodeInsideSubgraphTeller
(
argument
->
tensorrt_node_teller
()));
}
if
(
pass_name
==
"graph_viz_pass"
)
{
if
(
pass_name
==
"graph_viz_pass"
)
{
std
::
string
dot_file_path
=
std
::
to_string
(
pass_num
)
+
"_ir_"
+
std
::
string
dot_file_path
=
std
::
to_string
(
pass_num
)
+
"_ir_"
+
(
pre_pass
.
empty
()
?
"origin"
:
pre_pass
)
+
(
pre_pass
.
empty
()
?
"origin"
:
pre_pass
)
+
...
@@ -70,9 +63,6 @@ void IRPassManager::CreatePasses(Argument *argument,
...
@@ -70,9 +63,6 @@ void IRPassManager::CreatePasses(Argument *argument,
}
}
if
(
pass_name
==
"tensorrt_subgraph_pass"
)
{
if
(
pass_name
==
"tensorrt_subgraph_pass"
)
{
PADDLE_ENFORCE
(
argument
->
tensorrt_node_teller_valid
());
pass
->
SetNotOwned
(
"tensorrt_node_teller"
,
argument
->
tensorrt_node_teller_ptr
());
pass
->
Set
(
"workspace_size"
,
new
int
(
argument
->
tensorrt_workspace_size
()));
pass
->
Set
(
"workspace_size"
,
new
int
(
argument
->
tensorrt_workspace_size
()));
pass
->
Set
(
"max_batch_size"
,
new
int
(
argument
->
tensorrt_max_batch_size
()));
pass
->
Set
(
"max_batch_size"
,
new
int
(
argument
->
tensorrt_max_batch_size
()));
pass
->
Set
(
"min_subgraph_size"
,
pass
->
Set
(
"min_subgraph_size"
,
...
...
paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt
浏览文件 @
68a07328
cc_library
(
subgraph_detector SRCS subgraph_detector.cc DEPS proto_desc
)
cc_library
(
subgraph_detector SRCS subgraph_detector.cc DEPS proto_desc
)
cc_library
(
tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass.cc DEPS subgraph_detector
)
set
(
analysis_deps
${
analysis_deps
}
subgraph_detector tensorrt_subgraph_pass
CACHE INTERNAL
""
)
set
(
pass_file
${
PADDLE_BINARY_DIR
}
/paddle/fluid/inference/api/paddle_inference_pass.h
)
if
(
TENSORRT_FOUND
)
file
(
APPEND
${
pass_file
}
"USE_PASS(tensorrt_subgraph_pass);
\n
"
)
cc_library
(
tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass.cc DEPS subgraph_detector tensorrt_op_teller
)
set
(
INFER_IR_PASSES
${
INFER_IR_PASSES
}
tensorrt_subgraph_pass CACHE INTERNAL
""
)
set
(
analysis_deps
${
analysis_deps
}
subgraph_detector tensorrt_subgraph_pass
CACHE INTERNAL
""
)
set
(
pass_file
${
PADDLE_BINARY_DIR
}
/paddle/fluid/inference/api/paddle_inference_pass.h
)
file
(
APPEND
${
pass_file
}
"USE_PASS(tensorrt_subgraph_pass);
\n
"
)
set
(
INFER_IR_PASSES
${
INFER_IR_PASSES
}
tensorrt_subgraph_pass CACHE INTERNAL
""
)
endif
()
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
浏览文件 @
68a07328
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
#include "paddle/fluid/inference/tensorrt/op_teller.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -35,8 +36,10 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
...
@@ -35,8 +36,10 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
std
::
unique_ptr
<
framework
::
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
framework
::
ir
::
Graph
>
graph
)
const
{
framework
::
ir
::
FusePassBase
::
Init
(
"tensorrt_subgraph_pass"
,
graph
.
get
());
framework
::
ir
::
FusePassBase
::
Init
(
"tensorrt_subgraph_pass"
,
graph
.
get
());
auto
teller
=
auto
teller
=
[](
const
framework
::
ir
::
Node
*
node
)
{
Get
<
SubgraphDetector
::
NodeInsideSubgraphTeller
>
(
"tensorrt_node_teller"
);
if
(
!
node
->
IsOp
()
||
!
node
->
Op
())
return
false
;
return
tensorrt
::
OpTeller
::
Global
().
Tell
(
node
->
Op
()
->
Type
(),
*
node
->
Op
());
};
SubGraphFuser
fuser
(
graph
.
get
(),
teller
,
SubGraphFuser
fuser
(
graph
.
get
(),
teller
,
Get
<
int
>
(
"min_subgraph_size"
)
/*min subgraph size*/
);
Get
<
int
>
(
"min_subgraph_size"
)
/*min subgraph size*/
);
...
@@ -232,7 +235,6 @@ std::vector<std::string> ExtractParameters(
...
@@ -232,7 +235,6 @@ std::vector<std::string> ExtractParameters(
REGISTER_PASS
(
tensorrt_subgraph_pass
,
REGISTER_PASS
(
tensorrt_subgraph_pass
,
paddle
::
inference
::
analysis
::
TensorRtSubgraphPass
)
paddle
::
inference
::
analysis
::
TensorRtSubgraphPass
)
.
RequirePassAttr
(
"tensorrt_node_teller"
)
.
RequirePassAttr
(
"max_batch_size"
)
.
RequirePassAttr
(
"max_batch_size"
)
.
RequirePassAttr
(
"workspace_size"
)
.
RequirePassAttr
(
"workspace_size"
)
.
RequirePassAttr
(
"min_subgraph_size"
);
.
RequirePassAttr
(
"min_subgraph_size"
);
paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc
浏览文件 @
68a07328
...
@@ -27,9 +27,6 @@ namespace analysis {
...
@@ -27,9 +27,6 @@ namespace analysis {
void
IrAnalysisComposePass
::
RunImpl
(
Argument
*
argument
)
{
void
IrAnalysisComposePass
::
RunImpl
(
Argument
*
argument
)
{
ARGUMENT_CHECK_FIELD
(
argument
,
ir_analysis_passes
);
ARGUMENT_CHECK_FIELD
(
argument
,
ir_analysis_passes
);
if
(
argument
->
use_tensorrt_valid
()
&&
argument
->
use_tensorrt
())
{
InitTensorRTAttrs
(
argument
);
}
ApplyIrPasses
(
argument
);
ApplyIrPasses
(
argument
);
CollectFusionStatis
(
argument
);
CollectFusionStatis
(
argument
);
}
}
...
@@ -38,26 +35,6 @@ std::string IrAnalysisComposePass::repr() const {
...
@@ -38,26 +35,6 @@ std::string IrAnalysisComposePass::repr() const {
return
"ir-analysis-compose-pass"
;
return
"ir-analysis-compose-pass"
;
}
}
void
IrAnalysisComposePass
::
InitTensorRTAttrs
(
Argument
*
argument
)
{
if
(
argument
->
use_tensorrt_valid
()
&&
argument
->
use_tensorrt
())
{
LOG
(
INFO
)
<<
"Initing TensorRT pass"
;
argument
->
SetTensorRtNodeTeller
([](
const
framework
::
ir
::
Node
*
node
)
{
std
::
unordered_set
<
std
::
string
>
teller_set
(
{
"mul"
,
"conv2d"
,
"pool2d"
,
"relu"
,
"softmax"
,
"sigmoid"
,
"depthwise_conv2d"
,
"batch_norm"
,
"concat"
,
"tanh"
,
"pad"
,
"elementwise_add"
,
"elementwise_mul"
,
"dropout"
,
"split"
,
"prelu"
,
"conv2d_transpose"
,
"leaky_relu"
});
if
(
!
node
->
IsOp
())
return
false
;
if
(
teller_set
.
count
(
node
->
Op
()
->
Type
()))
{
return
true
;
}
else
{
return
false
;
}
});
}
}
void
IrAnalysisComposePass
::
ApplyIrPasses
(
Argument
*
argument
)
{
void
IrAnalysisComposePass
::
ApplyIrPasses
(
Argument
*
argument
)
{
std
::
vector
<
std
::
string
>
passes
({
std
::
vector
<
std
::
string
>
passes
({
"ir_graph_build_pass"
,
"ir_analysis_pass"
,
"ir_graph_build_pass"
,
"ir_analysis_pass"
,
...
...
paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.h
浏览文件 @
68a07328
...
@@ -33,8 +33,6 @@ class IrAnalysisComposePass : public AnalysisPass {
...
@@ -33,8 +33,6 @@ class IrAnalysisComposePass : public AnalysisPass {
std
::
string
repr
()
const
override
;
std
::
string
repr
()
const
override
;
private:
private:
void
InitTensorRTAttrs
(
Argument
*
argument
);
void
ApplyIrPasses
(
Argument
*
argument
);
void
ApplyIrPasses
(
Argument
*
argument
);
void
CollectFusionStatis
(
Argument
*
argument
);
void
CollectFusionStatis
(
Argument
*
argument
);
...
...
paddle/fluid/inference/api/analysis_config.cc
浏览文件 @
68a07328
...
@@ -14,86 +14,101 @@
...
@@ -14,86 +14,101 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_pass_builder.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle
_pass_builder.h" // NOLINT
#include "paddle
/fluid/platform/gpu_info.h"
namespace
paddle
{
namespace
paddle
{
PassStrategy
*
contrib
::
AnalysisConfig
::
pass_builder
()
const
{
PassStrategy
*
contrib
::
AnalysisConfig
::
pass_builder
()
const
{
PADDLE_ENFORCE
(
if
(
!
pass_builder_
.
get
())
{
pass_builder_
.
get
(),
if
(
use_gpu_
)
{
"Should call constructor first, that will init the pass_builder_."
);
LOG
(
INFO
)
<<
"Create GPU IR passes"
;
pass_builder_
.
reset
(
new
GpuPassStrategy
);
}
else
{
LOG
(
INFO
)
<<
"Create CPU IR passes"
;
pass_builder_
.
reset
(
new
CpuPassStrategy
);
}
}
else
if
(
pass_builder_
->
use_gpu
()
^
use_gpu
())
{
LOG
(
WARNING
)
<<
"The use_gpu flag is not compatible between Config and "
"PassBuilder, the flags are "
<<
use_gpu
()
<<
" "
<<
pass_builder_
->
use_gpu
();
LOG
(
WARNING
)
<<
"Please make them compatible, still use the existing "
"PassBuilder."
;
}
return
pass_builder_
.
get
();
return
pass_builder_
.
get
();
}
}
contrib
::
AnalysisConfig
::
AnalysisConfig
(
bool
use_gpu
)
{
contrib
::
AnalysisConfig
::
AnalysisConfig
(
const
std
::
string
&
model_dir
)
{
this
->
use_gpu
=
use_gpu
;
model_dir_
=
model_dir
;
if
(
use_gpu
)
{
}
pass_builder_
.
reset
(
new
GpuPassStrategy
);
contrib
::
AnalysisConfig
::
AnalysisConfig
(
const
std
::
string
&
prog_file
,
}
else
{
const
std
::
string
&
params_file
)
{
pass_builder_
.
reset
(
new
CpuPassStrategy
);
prog_file_
=
prog_file
;
}
params_file_
=
params_file
;
}
void
contrib
::
AnalysisConfig
::
SetModel
(
const
std
::
string
&
prog_file_path
,
const
std
::
string
&
params_file_path
)
{
prog_file_
=
prog_file_path
;
params_file_
=
params_file_path
;
}
void
contrib
::
AnalysisConfig
::
EnableUseGpu
(
uint64_t
memory_pool_init_size_mb
,
int
device_id
)
{
#ifdef PADDLE_WITH_CUDA
use_gpu_
=
true
;
memory_pool_init_size_mb_
=
memory_pool_init_size_mb
;
device_id_
=
device_id
;
#else
LOG
(
ERROR
)
<<
"Please compile with gpu to EnableGpu"
;
use_gpu_
=
false
;
#endif
}
}
void
contrib
::
AnalysisConfig
::
DisableGpu
()
{
use_gpu_
=
false
;
}
contrib
::
AnalysisConfig
::
AnalysisConfig
(
const
contrib
::
AnalysisConfig
&
other
)
{
contrib
::
AnalysisConfig
::
AnalysisConfig
(
const
contrib
::
AnalysisConfig
&
other
)
{
// fields from Config
#define CP_MEMBER(member__) member__ = other.member__;
model_dir
=
other
.
model_dir
;
// fields from NativeConfig
// Model related.
use_gpu
=
other
.
use_gpu
;
CP_MEMBER
(
model_dir_
);
device
=
other
.
device
;
CP_MEMBER
(
prog_file_
);
fraction_of_gpu_memory
=
other
.
fraction_of_gpu_memory
;
CP_MEMBER
(
params_file_
);
prog_file
=
other
.
prog_file
;
CP_MEMBER
(
model_from_memory_
);
// the memory model reuses prog_file_ and
param_file
=
other
.
param_file
;
// params_file_ fields.
specify_input_name
=
other
.
specify_input_name
;
// Gpu releated.
cpu_math_library_num_threads_
=
other
.
cpu_math_library_num_threads_
;
CP_MEMBER
(
use_gpu_
);
// fields from this.
CP_MEMBER
(
device_id_
);
enable_ir_optim
=
other
.
enable_ir_optim
;
CP_MEMBER
(
memory_pool_init_size_mb_
);
// For mkldnn
// TensorRT releated.
use_mkldnn_
=
other
.
use_mkldnn_
;
CP_MEMBER
(
use_tensorrt_
);
mkldnn_enabled_op_types_
=
other
.
mkldnn_enabled_op_types_
;
CP_MEMBER
(
tensorrt_workspace_size_
);
CP_MEMBER
(
tensorrt_max_batchsize_
);
use_feed_fetch_ops
=
other
.
use_feed_fetch_ops
;
CP_MEMBER
(
tensorrt_min_subgraph_size_
);
use_tensorrt_
=
other
.
use_tensorrt_
;
// MKLDNN releated.
tensorrt_max_batchsize_
=
other
.
tensorrt_max_batchsize_
;
CP_MEMBER
(
use_mkldnn_
);
tensorrt_workspace_size_
=
other
.
tensorrt_workspace_size_
;
CP_MEMBER
(
mkldnn_enabled_op_types_
);
tensorrt_min_subgraph_size_
=
other
.
tensorrt_min_subgraph_size_
;
model_from_memory_
=
other
.
model_from_memory_
;
// Ir related.
CP_MEMBER
(
enable_ir_optim_
);
if
(
use_gpu
)
{
CP_MEMBER
(
use_feed_fetch_ops_
);
CP_MEMBER
(
ir_debug_
);
CP_MEMBER
(
specify_input_name_
);
CP_MEMBER
(
cpu_math_library_num_threads_
);
CP_MEMBER
(
serialized_info_cache_
);
if
(
use_gpu_
)
{
pass_builder_
.
reset
(
new
GpuPassStrategy
(
pass_builder_
.
reset
(
new
GpuPassStrategy
(
*
static_cast
<
GpuPassStrategy
*>
(
other
.
pass_builder
())));
*
static_cast
<
GpuPassStrategy
*>
(
other
.
pass_builder
())));
}
else
{
}
else
{
pass_builder_
.
reset
(
new
CpuPassStrategy
(
pass_builder_
.
reset
(
new
CpuPassStrategy
(
*
static_cast
<
CpuPassStrategy
*>
(
other
.
pass_builder
())));
*
static_cast
<
CpuPassStrategy
*>
(
other
.
pass_builder
())));
}
}
}
contrib
::
AnalysisConfig
::
AnalysisConfig
(
contrib
::
AnalysisConfig
&&
other
)
{
#undef CP_MEMBER
// fields from Config
model_dir
=
other
.
model_dir
;
// fields from NativeConfig
use_gpu
=
other
.
use_gpu
;
device
=
other
.
device
;
fraction_of_gpu_memory
=
other
.
fraction_of_gpu_memory
;
prog_file
=
other
.
prog_file
;
param_file
=
other
.
param_file
;
specify_input_name
=
other
.
specify_input_name
;
cpu_math_library_num_threads_
=
other
.
cpu_math_library_num_threads_
;
// fields from this.
enable_ir_optim
=
other
.
enable_ir_optim
;
// For mkldnn
use_mkldnn_
=
other
.
use_mkldnn_
;
mkldnn_enabled_op_types_
=
other
.
mkldnn_enabled_op_types_
;
use_feed_fetch_ops
=
other
.
use_feed_fetch_ops
;
use_tensorrt_
=
other
.
use_tensorrt_
;
tensorrt_max_batchsize_
=
other
.
tensorrt_max_batchsize_
;
tensorrt_workspace_size_
=
other
.
tensorrt_workspace_size_
;
tensorrt_min_subgraph_size_
=
other
.
tensorrt_min_subgraph_size_
;
model_from_memory_
=
other
.
model_from_memory_
;
pass_builder_
=
std
::
move
(
other
.
pass_builder_
);
}
}
void
contrib
::
AnalysisConfig
::
EnableMKLDNN
()
{
void
contrib
::
AnalysisConfig
::
EnableMKLDNN
()
{
...
@@ -112,17 +127,90 @@ void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size,
...
@@ -112,17 +127,90 @@ void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size,
use_tensorrt_
=
true
;
use_tensorrt_
=
true
;
tensorrt_workspace_size_
=
workspace_size
;
tensorrt_workspace_size_
=
workspace_size
;
tensorrt_max_batchsize_
=
max_batch_size
;
tensorrt_max_batchsize_
=
max_batch_size
;
tensorrt_min_subgraph_size_
=
min_subgraph_size
;
}
// Append after the conv+affine_channel fuse pass.
pass_builder
()
->
InsertPass
(
3
,
"tensorrt_subgraph_pass"
);
void
contrib
::
AnalysisConfig
::
Update
()
{
auto
info
=
SerializeInfoCache
();
if
(
info
==
serialized_info_cache_
)
return
;
if
(
use_gpu_
)
{
pass_builder_
.
reset
(
new
GpuPassStrategy
);
}
else
{
pass_builder_
.
reset
(
new
CpuPassStrategy
);
}
if
(
use_tensorrt_
)
{
if
(
!
use_gpu_
)
{
LOG
(
ERROR
)
<<
"TensorRT engine is not available when EnableGpu() not actived."
;
}
else
{
// Append after the infer_clean pass.
pass_builder
()
->
InsertPass
(
1
,
"tensorrt_subgraph_pass"
);
}
}
if
(
use_mkldnn_
)
{
if
(
!
enable_ir_optim_
)
{
LOG
(
ERROR
)
<<
"EnableMKLDNN() only works when IR optimization is enabled."
;
}
#ifdef PADDLE_WITH_MKLDNN
pass_builder
()
->
EnableMKLDNN
();
use_mkldnn_
=
true
;
#else
LOG
(
ERROR
)
<<
"Please compile with MKLDNN first to use MKLDNN"
;
use_mkldnn_
=
false
;
#endif
}
if
(
ir_debug_
)
{
pass_builder
()
->
TurnOnDebug
();
}
}
std
::
string
contrib
::
AnalysisConfig
::
SerializeInfoCache
()
{
std
::
stringstream
ss
;
ss
<<
use_gpu_
;
ss
<<
memory_pool_init_size_mb_
;
ss
<<
use_tensorrt_
;
ss
<<
tensorrt_workspace_size_
;
ss
<<
tensorrt_max_batchsize_
;
ss
<<
use_mkldnn_
;
ss
<<
enable_ir_optim_
;
ss
<<
use_feed_fetch_ops_
;
ss
<<
ir_debug_
;
return
ss
.
str
();
}
void
contrib
::
AnalysisConfig
::
SetCpuMathLibraryNumThreads
(
int
cpu_math_library_num_threads
)
{
cpu_math_library_num_threads_
=
cpu_math_library_num_threads
;
}
float
contrib
::
AnalysisConfig
::
fraction_of_gpu_memory_for_pool
()
const
{
#ifdef PADDLE_WITH_CUDA
// Get the GPU memory details and calculate the fraction of memory for the
// GPU memory pool.
size_t
gpu_used
,
gpu_available
;
platform
::
GpuMemoryUsage
(
&
gpu_used
,
&
gpu_available
);
double
total_gpu_memory
=
(
gpu_used
+
gpu_available
)
/
1024.
/
1024.
;
float
fraction_of_gpu_memory
=
static_cast
<
double
>
(
memory_pool_init_size_mb
())
/
total_gpu_memory
;
return
fraction_of_gpu_memory
;
#else
return
0.
;
#endif
}
}
void
contrib
::
AnalysisConfig
::
SetModelBuffer
(
const
char
*
prog_buffer
,
void
contrib
::
AnalysisConfig
::
SetModelBuffer
(
const
char
*
prog_buffer
,
size_t
prog_buffer_size
,
size_t
prog_buffer_size
,
const
char
*
param_buffer
,
const
char
*
param_buffer
,
size_t
param_buffer_size
)
{
size_t
param_buffer_size
)
{
prog_file
=
std
::
string
(
prog_buffer
,
prog_buffer
+
prog_buffer_size
);
prog_file
_
=
std
::
string
(
prog_buffer
,
prog_buffer
+
prog_buffer_size
);
param
_file
=
std
::
string
(
param_buffer
,
param_buffer
+
param_buffer_size
);
param
s_file_
=
std
::
string
(
param_buffer
,
param_buffer
+
param_buffer_size
);
model_from_memory_
=
true
;
model_from_memory_
=
true
;
}
}
...
...
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
68a07328
...
@@ -33,6 +33,7 @@
...
@@ -33,6 +33,7 @@
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
DECLARE_bool
(
profile
);
DECLARE_bool
(
profile
);
...
@@ -59,8 +60,8 @@ bool AnalysisPredictor::Init(
...
@@ -59,8 +60,8 @@ bool AnalysisPredictor::Init(
if
(
FLAGS_profile
)
{
if
(
FLAGS_profile
)
{
LOG
(
WARNING
)
<<
"Profiler is actived, might affect the performance"
;
LOG
(
WARNING
)
<<
"Profiler is actived, might affect the performance"
;
LOG
(
INFO
)
<<
"You can turn off by set gflags '-profile false'"
;
LOG
(
INFO
)
<<
"You can turn off by set gflags '-profile false'"
;
auto
tracking_device
=
config_
.
use_gpu
?
platform
::
ProfilerState
::
kAll
auto
tracking_device
=
config_
.
use_gpu
()
?
platform
::
ProfilerState
::
kAll
:
platform
::
ProfilerState
::
kCPU
;
:
platform
::
ProfilerState
::
kCPU
;
platform
::
EnableProfiler
(
tracking_device
);
platform
::
EnableProfiler
(
tracking_device
);
}
}
...
@@ -112,7 +113,7 @@ bool AnalysisPredictor::PrepareProgram(
...
@@ -112,7 +113,7 @@ bool AnalysisPredictor::PrepareProgram(
// Optimize the program, and load parameters and modify them in the
// Optimize the program, and load parameters and modify them in the
// scope_.
// scope_.
// This will change the scope_ address.
// This will change the scope_ address.
if
(
config_
.
enable_ir_optim
)
{
if
(
config_
.
ir_optim
()
)
{
status_ir_optim_enabled_
=
true
;
status_ir_optim_enabled_
=
true
;
OptimizeInferenceProgram
();
OptimizeInferenceProgram
();
}
else
{
}
else
{
...
@@ -140,9 +141,9 @@ bool AnalysisPredictor::PrepareProgram(
...
@@ -140,9 +141,9 @@ bool AnalysisPredictor::PrepareProgram(
return
true
;
return
true
;
}
}
bool
AnalysisPredictor
::
CreateExecutor
()
{
bool
AnalysisPredictor
::
CreateExecutor
()
{
if
(
config_
.
use_gpu
)
{
if
(
config_
.
use_gpu
_
)
{
status_use_gpu_
=
true
;
status_use_gpu_
=
true
;
place_
=
paddle
::
platform
::
CUDAPlace
(
config_
.
device
);
place_
=
paddle
::
platform
::
CUDAPlace
(
config_
.
device
_id_
);
}
else
{
}
else
{
place_
=
paddle
::
platform
::
CPUPlace
();
place_
=
paddle
::
platform
::
CPUPlace
();
}
}
...
@@ -151,7 +152,7 @@ bool AnalysisPredictor::CreateExecutor() {
...
@@ -151,7 +152,7 @@ bool AnalysisPredictor::CreateExecutor() {
}
}
bool
AnalysisPredictor
::
PrepareExecutor
()
{
bool
AnalysisPredictor
::
PrepareExecutor
()
{
executor_
->
Prepare
(
sub_scope_
,
*
inference_program_
,
0
,
executor_
->
Prepare
(
sub_scope_
,
*
inference_program_
,
0
,
config_
.
use_feed_fetch_ops
);
config_
.
use_feed_fetch_ops
_
);
PADDLE_ENFORCE_NOT_NULL
(
sub_scope_
);
PADDLE_ENFORCE_NOT_NULL
(
sub_scope_
);
...
@@ -250,7 +251,7 @@ bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
...
@@ -250,7 +251,7 @@ bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
}
}
input
.
set_lod
(
lod
);
input
.
set_lod
(
lod
);
int
idx
=
-
1
;
int
idx
=
-
1
;
if
(
config_
.
specify_input_name
)
{
if
(
config_
.
specify_input_name
_
)
{
auto
name
=
inputs
[
i
].
name
;
auto
name
=
inputs
[
i
].
name
;
if
(
feed_names_
.
find
(
name
)
==
feed_names_
.
end
())
{
if
(
feed_names_
.
find
(
name
)
==
feed_names_
.
end
())
{
LOG
(
ERROR
)
<<
"feed names from program do not have name: ["
<<
name
LOG
(
ERROR
)
<<
"feed names from program do not have name: ["
<<
name
...
@@ -314,22 +315,22 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
...
@@ -314,22 +315,22 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
void
AnalysisPredictor
::
OptimizeInferenceProgram
()
{
void
AnalysisPredictor
::
OptimizeInferenceProgram
()
{
status_program_optimized_
=
true
;
status_program_optimized_
=
true
;
argument_
.
SetUseGPU
(
config_
.
use_gpu
);
argument_
.
SetUseGPU
(
config_
.
use_gpu
()
);
argument_
.
SetGPUDeviceId
(
config_
.
device
);
argument_
.
SetGPUDeviceId
(
config_
.
gpu_device_id
()
);
argument_
.
SetModelFromMemory
(
config_
.
model_from_memory_
);
argument_
.
SetModelFromMemory
(
config_
.
model_from_memory_
);
// Analyze inference_program
// Analyze inference_program
if
(
!
config_
.
model_dir
.
empty
())
{
if
(
!
config_
.
model_dir
()
.
empty
())
{
argument_
.
SetModelDir
(
config_
.
model_dir
);
argument_
.
SetModelDir
(
config_
.
model_dir
()
);
}
else
{
}
else
{
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
!
config_
.
param
_file
.
empty
(),
!
config_
.
param
s_file
()
.
empty
(),
"Either model_dir or (param_file, prog_file) should be set."
);
"Either model_dir or (param_file, prog_file) should be set."
);
PADDLE_ENFORCE
(
!
config_
.
prog_file
.
empty
());
PADDLE_ENFORCE
(
!
config_
.
prog_file
()
.
empty
());
argument_
.
SetModelProgramPath
(
config_
.
prog_file
);
argument_
.
SetModelProgramPath
(
config_
.
prog_file
()
);
argument_
.
SetModelParamsPath
(
config_
.
param
_file
);
argument_
.
SetModelParamsPath
(
config_
.
param
s_file
()
);
}
}
if
(
config_
.
use_gpu
&&
config_
.
use_tensorrt_
)
{
if
(
config_
.
use_gpu
()
&&
config_
.
tensorrt_engine_enabled
()
)
{
argument_
.
SetUseTensorRT
(
true
);
argument_
.
SetUseTensorRT
(
true
);
argument_
.
SetTensorRtWorkspaceSize
(
config_
.
tensorrt_workspace_size_
);
argument_
.
SetTensorRtWorkspaceSize
(
config_
.
tensorrt_workspace_size_
);
argument_
.
SetTensorRtMaxBatchSize
(
config_
.
tensorrt_max_batchsize_
);
argument_
.
SetTensorRtMaxBatchSize
(
config_
.
tensorrt_max_batchsize_
);
...
@@ -341,7 +342,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
...
@@ -341,7 +342,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
}
}
auto
passes
=
config_
.
pass_builder
()
->
AllPasses
();
auto
passes
=
config_
.
pass_builder
()
->
AllPasses
();
if
(
!
config_
.
enable_ir_optim
)
passes
.
clear
();
if
(
!
config_
.
ir_optim
()
)
passes
.
clear
();
argument_
.
SetIrAnalysisPasses
(
passes
);
argument_
.
SetIrAnalysisPasses
(
passes
);
argument_
.
SetScopeNotOwned
(
const_cast
<
framework
::
Scope
*>
(
scope_
.
get
()));
argument_
.
SetScopeNotOwned
(
const_cast
<
framework
::
Scope
*>
(
scope_
.
get
()));
Analyzer
().
Run
(
&
argument_
);
Analyzer
().
Run
(
&
argument_
);
...
@@ -358,18 +359,26 @@ template <>
...
@@ -358,18 +359,26 @@ template <>
std
::
unique_ptr
<
PaddlePredictor
>
CreatePaddlePredictor
<
std
::
unique_ptr
<
PaddlePredictor
>
CreatePaddlePredictor
<
AnalysisConfig
,
PaddleEngineKind
::
kAnalysis
>
(
const
AnalysisConfig
&
config
)
{
AnalysisConfig
,
PaddleEngineKind
::
kAnalysis
>
(
const
AnalysisConfig
&
config
)
{
VLOG
(
3
)
<<
"create AnalysisConfig"
;
VLOG
(
3
)
<<
"create AnalysisConfig"
;
if
(
config
.
use_gpu
)
{
if
(
config
.
use_gpu
()
)
{
// 1. GPU memeroy
// 1. GPU memeroy
PADDLE_ENFORCE_GT
(
PADDLE_ENFORCE_GT
(
config
.
memory_pool_init_size_mb
(),
0.
f
);
config
.
fraction_of_gpu_memory
,
0.
f
,
PADDLE_ENFORCE_GE
(
config
.
gpu_device_id
(),
0
,
"Invalid device id %d"
,
"fraction_of_gpu_memory in the config should be set to range (0., 1.]"
);
config
.
gpu_device_id
());
PADDLE_ENFORCE_GE
(
config
.
device
,
0
,
"Invalid device id %d"
,
config
.
device
);
std
::
vector
<
std
::
string
>
flags
;
std
::
vector
<
std
::
string
>
flags
;
if
(
config
.
fraction_of_gpu_memory
>=
0.0
f
||
config
.
fraction_of_gpu_memory
<=
0.95
f
)
{
float
fraction_of_gpu_memory
=
config
.
fraction_of_gpu_memory_for_pool
();
if
(
fraction_of_gpu_memory
>
0.95
f
)
{
LOG
(
ERROR
)
<<
"Allocate too much memory for the GPU memory pool, assigned "
<<
config
.
memory_pool_init_size_mb
()
<<
" MB"
;
LOG
(
ERROR
)
<<
"Try to shink the value by setting AnalysisConfig::EnableGpu(...)"
;
}
if
(
fraction_of_gpu_memory
>=
0.0
f
||
fraction_of_gpu_memory
<=
0.95
f
)
{
flags
.
push_back
(
"dummpy"
);
flags
.
push_back
(
"dummpy"
);
std
::
string
flag
=
"--fraction_of_gpu_memory_to_use="
+
std
::
string
flag
=
"--fraction_of_gpu_memory_to_use="
+
std
::
to_string
(
config
.
fraction_of_gpu_memory
);
std
::
to_string
(
fraction_of_gpu_memory
);
flags
.
push_back
(
flag
);
flags
.
push_back
(
flag
);
VLOG
(
3
)
<<
"set flag: "
<<
flag
;
VLOG
(
3
)
<<
"set flag: "
<<
flag
;
framework
::
InitGflags
(
flags
);
framework
::
InitGflags
(
flags
);
...
@@ -443,22 +452,22 @@ bool AnalysisPredictor::ZeroCopyRun() {
...
@@ -443,22 +452,22 @@ bool AnalysisPredictor::ZeroCopyRun() {
bool
AnalysisPredictor
::
LoadProgramDesc
()
{
bool
AnalysisPredictor
::
LoadProgramDesc
()
{
// Initialize the inference program
// Initialize the inference program
std
::
string
filename
;
std
::
string
filename
;
if
(
!
config_
.
model_dir
.
empty
())
{
if
(
!
config_
.
model_dir
()
.
empty
())
{
filename
=
config_
.
model_dir
+
"/__model__"
;
filename
=
config_
.
model_dir
()
+
"/__model__"
;
}
else
if
(
!
config_
.
prog_file
.
empty
()
&&
!
config_
.
param_file
.
empty
())
{
}
else
if
(
!
config_
.
prog_file
().
empty
()
&&
!
config_
.
params_file
()
.
empty
())
{
// All parameters are saved in a single file.
// All parameters are saved in a single file.
// The file names should be consistent with that used
// The file names should be consistent with that used
// in Python API `fluid.io.save_inference_model`.
// in Python API `fluid.io.save_inference_model`.
filename
=
config_
.
prog_file
;
filename
=
config_
.
prog_file
()
;
}
else
{
}
else
{
if
(
config_
.
model_dir
.
empty
()
&&
config_
.
prog_file
.
empty
())
{
if
(
config_
.
model_dir
().
empty
()
&&
config_
.
prog_file
()
.
empty
())
{
LOG
(
ERROR
)
LOG
(
ERROR
)
<<
"Either model_dir or (prog_file, param_file) should be set."
;
<<
"Either model_dir or (prog_file, param_file) should be set."
;
return
false
;
return
false
;
}
}
LOG
(
ERROR
)
<<
string
::
Sprintf
(
LOG
(
ERROR
)
<<
string
::
Sprintf
(
"not valid model path '%s' or program path '%s'."
,
config_
.
model_dir
,
"not valid model path '%s' or program path '%s'."
,
config_
.
model_dir
()
,
config_
.
param
_file
);
config_
.
param
s_file
()
);
return
false
;
return
false
;
}
}
...
@@ -478,7 +487,7 @@ bool AnalysisPredictor::LoadProgramDesc() {
...
@@ -478,7 +487,7 @@ bool AnalysisPredictor::LoadProgramDesc() {
proto
.
ParseFromString
(
pb_content
);
proto
.
ParseFromString
(
pb_content
);
}
else
{
}
else
{
proto
.
ParseFromString
(
config_
.
prog_file
);
proto
.
ParseFromString
(
config_
.
prog_file
()
);
}
}
inference_program_
.
reset
(
new
framework
::
ProgramDesc
(
proto
));
inference_program_
.
reset
(
new
framework
::
ProgramDesc
(
proto
));
return
true
;
return
true
;
...
@@ -508,27 +517,27 @@ bool AnalysisPredictor::LoadParameters() {
...
@@ -508,27 +517,27 @@ bool AnalysisPredictor::LoadParameters() {
new_var
->
SetLoDLevel
(
var
->
GetLoDLevel
());
new_var
->
SetLoDLevel
(
var
->
GetLoDLevel
());
new_var
->
SetPersistable
(
true
);
new_var
->
SetPersistable
(
true
);
if
(
!
config_
.
param
_file
.
empty
())
{
if
(
!
config_
.
param
s_file
()
.
empty
())
{
params
.
push_back
(
new_var
->
Name
());
params
.
push_back
(
new_var
->
Name
());
}
else
{
}
else
{
// append_op
// append_op
framework
::
OpDesc
*
op
=
load_block
->
AppendOp
();
framework
::
OpDesc
*
op
=
load_block
->
AppendOp
();
op
->
SetType
(
"load"
);
op
->
SetType
(
"load"
);
op
->
SetOutput
(
"Out"
,
{
new_var
->
Name
()});
op
->
SetOutput
(
"Out"
,
{
new_var
->
Name
()});
op
->
SetAttr
(
"file_path"
,
{
config_
.
model_dir
+
"/"
+
new_var
->
Name
()});
op
->
SetAttr
(
"file_path"
,
{
config_
.
model_dir
()
+
"/"
+
new_var
->
Name
()});
op
->
CheckAttrs
();
op
->
CheckAttrs
();
}
}
}
}
}
}
if
(
!
config_
.
param
_file
.
empty
())
{
if
(
!
config_
.
param
s_file
()
.
empty
())
{
// sort paramlist to have consistent ordering
// sort paramlist to have consistent ordering
std
::
sort
(
params
.
begin
(),
params
.
end
());
std
::
sort
(
params
.
begin
(),
params
.
end
());
// append just the load_combine op
// append just the load_combine op
framework
::
OpDesc
*
op
=
load_block
->
AppendOp
();
framework
::
OpDesc
*
op
=
load_block
->
AppendOp
();
op
->
SetType
(
"load_combine"
);
op
->
SetType
(
"load_combine"
);
op
->
SetOutput
(
"Out"
,
params
);
op
->
SetOutput
(
"Out"
,
params
);
op
->
SetAttr
(
"file_path"
,
{
config_
.
param
_file
});
op
->
SetAttr
(
"file_path"
,
{
config_
.
param
s_file
()
});
op
->
CheckAttrs
();
op
->
CheckAttrs
();
}
}
...
...
paddle/fluid/inference/api/analysis_predictor_tester.cc
浏览文件 @
68a07328
...
@@ -25,9 +25,9 @@ namespace paddle {
...
@@ -25,9 +25,9 @@ namespace paddle {
using
contrib
::
AnalysisConfig
;
using
contrib
::
AnalysisConfig
;
TEST
(
AnalysisPredictor
,
analysis_off
)
{
TEST
(
AnalysisPredictor
,
analysis_off
)
{
AnalysisConfig
config
(
false
)
;
AnalysisConfig
config
;
config
.
model_dir
=
FLAGS_dirname
;
config
.
SetModel
(
FLAGS_dirname
)
;
config
.
enable_ir_optim
=
false
;
config
.
SwitchIrOptim
(
false
)
;
auto
_predictor
=
CreatePaddlePredictor
<
AnalysisConfig
>
(
config
);
auto
_predictor
=
CreatePaddlePredictor
<
AnalysisConfig
>
(
config
);
auto
*
predictor
=
static_cast
<
AnalysisPredictor
*>
(
_predictor
.
get
());
auto
*
predictor
=
static_cast
<
AnalysisPredictor
*>
(
_predictor
.
get
());
...
@@ -55,14 +55,14 @@ TEST(AnalysisPredictor, analysis_off) {
...
@@ -55,14 +55,14 @@ TEST(AnalysisPredictor, analysis_off) {
}
}
TEST
(
AnalysisPredictor
,
analysis_on
)
{
TEST
(
AnalysisPredictor
,
analysis_on
)
{
AnalysisConfig
config
;
config
.
SetModel
(
FLAGS_dirname
);
config
.
SwitchIrOptim
(
true
);
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
AnalysisConfig
config
(
true
);
config
.
EnableUseGpu
(
100
,
0
);
config
.
fraction_of_gpu_memory
=
0.15
;
#else
#else
AnalysisConfig
config
;
config
.
DisableGpu
()
;
#endif
#endif
config
.
model_dir
=
FLAGS_dirname
;
config
.
enable_ir_optim
=
true
;
auto
_predictor
=
CreatePaddlePredictor
<
AnalysisConfig
>
(
config
);
auto
_predictor
=
CreatePaddlePredictor
<
AnalysisConfig
>
(
config
);
auto
*
predictor
=
static_cast
<
AnalysisPredictor
*>
(
_predictor
.
get
());
auto
*
predictor
=
static_cast
<
AnalysisPredictor
*>
(
_predictor
.
get
());
...
@@ -89,7 +89,8 @@ TEST(AnalysisPredictor, analysis_on) {
...
@@ -89,7 +89,8 @@ TEST(AnalysisPredictor, analysis_on) {
}
}
// compare with NativePredictor
// compare with NativePredictor
auto
naive_predictor
=
CreatePaddlePredictor
<
NativeConfig
>
(
config
);
auto
naive_predictor
=
CreatePaddlePredictor
<
NativeConfig
>
(
config
.
ToNativeConfig
());
std
::
vector
<
PaddleTensor
>
naive_outputs
;
std
::
vector
<
PaddleTensor
>
naive_outputs
;
ASSERT_TRUE
(
naive_predictor
->
Run
(
inputs
,
&
naive_outputs
));
ASSERT_TRUE
(
naive_predictor
->
Run
(
inputs
,
&
naive_outputs
));
ASSERT_EQ
(
naive_outputs
.
size
(),
1UL
);
ASSERT_EQ
(
naive_outputs
.
size
(),
1UL
);
...
@@ -98,9 +99,8 @@ TEST(AnalysisPredictor, analysis_on) {
...
@@ -98,9 +99,8 @@ TEST(AnalysisPredictor, analysis_on) {
TEST
(
AnalysisPredictor
,
ZeroCopy
)
{
TEST
(
AnalysisPredictor
,
ZeroCopy
)
{
AnalysisConfig
config
;
AnalysisConfig
config
;
config
.
model_dir
=
FLAGS_dirname
;
config
.
SetModel
(
FLAGS_dirname
);
config
.
use_feed_fetch_ops
=
false
;
config
.
SwitchUseFeedFetchOps
(
false
);
auto
predictor
=
CreatePaddlePredictor
<
AnalysisConfig
>
(
config
);
auto
predictor
=
CreatePaddlePredictor
<
AnalysisConfig
>
(
config
);
auto
w0
=
predictor
->
GetInputTensor
(
"firstw"
);
auto
w0
=
predictor
->
GetInputTensor
(
"firstw"
);
...
@@ -137,9 +137,9 @@ TEST(AnalysisPredictor, ZeroCopy) {
...
@@ -137,9 +137,9 @@ TEST(AnalysisPredictor, ZeroCopy) {
TEST
(
AnalysisPredictor
,
Clone
)
{
TEST
(
AnalysisPredictor
,
Clone
)
{
AnalysisConfig
config
;
AnalysisConfig
config
;
config
.
model_dir
=
FLAGS_dirname
;
config
.
SetModel
(
FLAGS_dirname
)
;
config
.
use_feed_fetch_ops
=
true
;
config
.
SwitchUseFeedFetchOps
(
true
)
;
config
.
enable_ir_optim
=
true
;
config
.
SwitchIrOptim
(
true
)
;
std
::
vector
<
std
::
unique_ptr
<
PaddlePredictor
>>
predictors
;
std
::
vector
<
std
::
unique_ptr
<
PaddlePredictor
>>
predictors
;
predictors
.
emplace_back
(
CreatePaddlePredictor
(
config
));
predictors
.
emplace_back
(
CreatePaddlePredictor
(
config
));
...
...
paddle/fluid/inference/api/api_anakin_engine.h
浏览文件 @
68a07328
...
@@ -19,8 +19,6 @@ limitations under the License. */
...
@@ -19,8 +19,6 @@ limitations under the License. */
#pragma once
#pragma once
#define WITH_ANAKIN
#include <vector>
#include <vector>
#include "framework/core/net/net.h"
#include "framework/core/net/net.h"
...
...
paddle/fluid/inference/api/api_impl.cc
浏览文件 @
68a07328
...
@@ -288,7 +288,7 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
...
@@ -288,7 +288,7 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
VLOG
(
3
)
<<
"create NativePaddlePredictor"
;
VLOG
(
3
)
<<
"create NativePaddlePredictor"
;
if
(
config
.
use_gpu
)
{
if
(
config
.
use_gpu
)
{
// 1. GPU memeroy
// 1. GPU memeroy
PADDLE_ENFORCE_G
T
(
PADDLE_ENFORCE_G
E
(
config
.
fraction_of_gpu_memory
,
0.
f
,
config
.
fraction_of_gpu_memory
,
0.
f
,
"fraction_of_gpu_memory in the config should be set to range (0., 1.]"
);
"fraction_of_gpu_memory in the config should be set to range (0., 1.]"
);
PADDLE_ENFORCE_GE
(
config
.
device
,
0
,
"Invalid device id %d"
,
config
.
device
);
PADDLE_ENFORCE_GE
(
config
.
device
,
0
,
"Invalid device id %d"
,
config
.
device
);
...
...
paddle/fluid/inference/api/api_impl_tester.cc
浏览文件 @
68a07328
...
@@ -295,7 +295,8 @@ TEST(inference_api_native, image_classification_gpu) {
...
@@ -295,7 +295,8 @@ TEST(inference_api_native, image_classification_gpu) {
#endif
#endif
TEST
(
PassBuilder
,
Delete
)
{
TEST
(
PassBuilder
,
Delete
)
{
contrib
::
AnalysisConfig
config
(
false
);
contrib
::
AnalysisConfig
config
;
config
.
DisableGpu
();
config
.
pass_builder
()
->
DeletePass
(
"attention_lstm_fuse_pass"
);
config
.
pass_builder
()
->
DeletePass
(
"attention_lstm_fuse_pass"
);
const
auto
&
passes
=
config
.
pass_builder
()
->
AllPasses
();
const
auto
&
passes
=
config
.
pass_builder
()
->
AllPasses
();
auto
it
=
std
::
find
(
passes
.
begin
(),
passes
.
end
(),
"attention_lstm_fuse_pass"
);
auto
it
=
std
::
find
(
passes
.
begin
(),
passes
.
end
(),
"attention_lstm_fuse_pass"
);
...
...
paddle/fluid/inference/api/demo_ci/trt_mobilenet_demo.cc
浏览文件 @
68a07328
...
@@ -36,12 +36,11 @@ namespace demo {
...
@@ -36,12 +36,11 @@ namespace demo {
*/
*/
void
Main
()
{
void
Main
()
{
std
::
unique_ptr
<
PaddlePredictor
>
predictor
;
std
::
unique_ptr
<
PaddlePredictor
>
predictor
;
paddle
::
contrib
::
AnalysisConfig
config
(
true
)
;
paddle
::
contrib
::
AnalysisConfig
config
;
config
.
param_file
=
FLAGS_modeldir
+
"/__params__"
;
config
.
EnableUseGpu
(
100
,
0
)
;
config
.
prog_file
=
FLAGS_modeldir
+
"/__model__"
;
config
.
SetModel
(
FLAGS_modeldir
+
"/__params__"
,
config
.
device
=
0
;
FLAGS_modeldir
+
"/__model__"
)
;
config
.
EnableTensorRtEngine
();
config
.
EnableTensorRtEngine
();
config
.
fraction_of_gpu_memory
=
0.1
;
// set by yourself
predictor
=
CreatePaddlePredictor
(
config
);
predictor
=
CreatePaddlePredictor
(
config
);
VLOG
(
3
)
<<
"begin to process data"
;
VLOG
(
3
)
<<
"begin to process data"
;
...
...
paddle/fluid/inference/api/demo_ci/vis_demo.cc
浏览文件 @
68a07328
...
@@ -40,15 +40,14 @@ using contrib::AnalysisConfig;
...
@@ -40,15 +40,14 @@ using contrib::AnalysisConfig;
*/
*/
void
Main
(
bool
use_gpu
)
{
void
Main
(
bool
use_gpu
)
{
std
::
unique_ptr
<
PaddlePredictor
>
predictor
,
analysis_predictor
;
std
::
unique_ptr
<
PaddlePredictor
>
predictor
,
analysis_predictor
;
AnalysisConfig
config
(
use_gpu
);
AnalysisConfig
config
;
config
.
param_file
=
FLAGS_modeldir
+
"/__params__"
;
if
(
use_gpu
)
{
config
.
prog_file
=
FLAGS_modeldir
+
"/__model__"
;
config
.
EnableUseGpu
(
100
,
0
);
config
.
device
=
0
;
if
(
FLAGS_use_gpu
)
{
config
.
fraction_of_gpu_memory
=
0.1
;
// set by yourself
}
}
config
.
SetModel
(
FLAGS_modeldir
+
"/__model__"
,
FLAGS_modeldir
+
"/__params__"
);
predictor
=
CreatePaddlePredictor
<
NativeConfig
>
(
config
);
predictor
=
CreatePaddlePredictor
<
NativeConfig
>
(
config
.
ToNativeConfig
()
);
analysis_predictor
=
CreatePaddlePredictor
(
config
);
analysis_predictor
=
CreatePaddlePredictor
(
config
);
// Just a single batch of data.
// Just a single batch of data.
...
...
paddle/fluid/inference/api/paddle_analysis_config.h
浏览文件 @
68a07328
...
@@ -34,26 +34,67 @@ class AnalysisPredictor;
...
@@ -34,26 +34,67 @@ class AnalysisPredictor;
namespace
contrib
{
namespace
contrib
{
// NOTE WIP, not stable yet.
// NOTE WIP, not stable yet.
struct
AnalysisConfig
:
public
NativeConfig
{
struct
AnalysisConfig
{
explicit
AnalysisConfig
(
bool
use_gpu
=
false
)
;
AnalysisConfig
()
=
default
;
explicit
AnalysisConfig
(
const
AnalysisConfig
&
other
);
explicit
AnalysisConfig
(
const
AnalysisConfig
&
other
);
explicit
AnalysisConfig
(
AnalysisConfig
&&
other
);
explicit
AnalysisConfig
(
const
std
::
string
&
model_dir
);
explicit
AnalysisConfig
(
const
std
::
string
&
prog_file
,
const
std
::
string
&
params_file
);
// Model path related.
void
SetModel
(
const
std
::
string
&
model_dir
)
{
model_dir_
=
model_dir
;
}
void
SetModel
(
const
std
::
string
&
prog_file_path
,
const
std
::
string
&
params_file_path
);
void
SetProgFile
(
const
std
::
string
&
x
)
{
prog_file_
=
x
;
}
void
SetParamsFile
(
const
std
::
string
&
x
)
{
params_file_
=
x
;
}
const
std
::
string
&
model_dir
()
const
{
return
model_dir_
;
}
const
std
::
string
&
prog_file
()
const
{
return
prog_file_
;
}
const
std
::
string
&
params_file
()
const
{
return
params_file_
;
}
// GPU related.
void
EnableUseGpu
(
uint64_t
memory_pool_init_size_mb
,
int
device_id
=
0
);
void
DisableGpu
();
bool
use_gpu
()
const
{
return
use_gpu_
;
}
int
gpu_device_id
()
const
{
return
device_id_
;
}
int
memory_pool_init_size_mb
()
const
{
return
memory_pool_init_size_mb_
;
}
float
fraction_of_gpu_memory_for_pool
()
const
;
// Determine whether to perform graph optimization.
// Determine whether to perform graph optimization.
bool
enable_ir_optim
=
true
;
void
SwitchIrOptim
(
int
x
=
true
)
{
enable_ir_optim_
=
x
;
}
bool
ir_optim
()
const
{
return
enable_ir_optim_
;
}
// Get a pass builder for customize the passes in IR analysis phase.
void
SwitchUseFeedFetchOps
(
int
x
=
true
)
{
use_feed_fetch_ops_
=
x
;
}
PassStrategy
*
pass_builder
()
const
;
bool
use_feed_fetch_ops_enabled
()
const
{
return
use_feed_fetch_ops_
;
}
// NOT stable yet.
void
SwitchSpecifyInputNames
(
bool
x
=
true
)
{
specify_input_name_
=
x
;
}
bool
use_feed_fetch_ops
{
true
};
bool
specify_input_name
()
const
{
return
specify_input_name_
;
}
void
EnableTensorRtEngine
(
int
workspace_size
=
1
<<
20
,
void
EnableTensorRtEngine
(
int
workspace_size
=
1
<<
20
,
int
max_batch_size
=
1
,
int
min_subgraph_size
=
3
);
int
max_batch_size
=
1
,
int
min_subgraph_size
=
3
);
bool
use_tensorrt
()
const
{
return
use_tensorrt_
;
}
bool
tensorrt_engine_enabled
()
const
{
return
use_tensorrt_
;
}
void
SwitchIrDebug
(
int
x
=
true
)
{
ir_debug_
=
x
;
}
void
EnableMKLDNN
();
void
EnableMKLDNN
();
bool
use_mkldnn
()
const
{
return
use_mkldnn_
;
}
bool
mkldnn_enabled
()
const
{
return
use_mkldnn_
;
}
// Set and get the number of cpu math library threads.
void
SetCpuMathLibraryNumThreads
(
int
cpu_math_library_num_threads
);
int
cpu_math_library_num_threads
()
const
{
return
cpu_math_library_num_threads_
;
}
NativeConfig
ToNativeConfig
()
const
{
NativeConfig
config
;
config
.
model_dir
=
model_dir_
;
config
.
prog_file
=
prog_file_
;
config
.
param_file
=
params_file_
;
config
.
use_gpu
=
use_gpu_
;
config
.
device
=
device_id_
;
config
.
fraction_of_gpu_memory
=
fraction_of_gpu_memory_for_pool
();
config
.
specify_input_name
=
specify_input_name_
;
return
config
;
}
void
SetMKLDNNOp
(
std
::
unordered_set
<
std
::
string
>
op_list
)
{
void
SetMKLDNNOp
(
std
::
unordered_set
<
std
::
string
>
op_list
)
{
mkldnn_enabled_op_types_
=
op_list
;
mkldnn_enabled_op_types_
=
op_list
;
}
}
...
@@ -65,10 +106,29 @@ struct AnalysisConfig : public NativeConfig {
...
@@ -65,10 +106,29 @@ struct AnalysisConfig : public NativeConfig {
friend
class
::
paddle
::
AnalysisPredictor
;
friend
class
::
paddle
::
AnalysisPredictor
;
// NOTE just for developer, not an official API, easily to be broken.
// Get a pass builder for customize the passes in IR analysis phase.
PassStrategy
*
pass_builder
()
const
;
protected:
// Update the config.
void
Update
();
std
::
string
SerializeInfoCache
();
protected:
protected:
// Model pathes.
std
::
string
model_dir_
;
std
::
string
prog_file_
;
std
::
string
params_file_
;
// GPU releated.
bool
use_gpu_
{
false
};
int
device_id_
{
0
};
uint64_t
memory_pool_init_size_mb_
{
100
};
// initial size is 100MB.
// TensorRT releated.
bool
use_tensorrt_
{
false
};
bool
use_tensorrt_
{
false
};
bool
use_mkldnn_
{
false
};
std
::
unordered_set
<
std
::
string
>
mkldnn_enabled_op_types_
;
// For workspace_size, refer it from here:
// For workspace_size, refer it from here:
// https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#troubleshooting
// https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#troubleshooting
int
tensorrt_workspace_size_
;
int
tensorrt_workspace_size_
;
...
@@ -82,17 +142,24 @@ struct AnalysisConfig : public NativeConfig {
...
@@ -82,17 +142,24 @@ struct AnalysisConfig : public NativeConfig {
// We set this variable to control the minimum number of nodes in the
// We set this variable to control the minimum number of nodes in the
// subgraph, 3 as default value.
// subgraph, 3 as default value.
int
tensorrt_min_subgraph_size_
{
3
};
int
tensorrt_min_subgraph_size_
{
3
};
std
::
unique_ptr
<
PassStrategy
>
pass_builder_
;
bool
use_mkldnn_
{
false
};
std
::
unordered_set
<
std
::
string
>
mkldnn_enabled_op_types_
;
bool
model_from_memory_
{
false
};
bool
model_from_memory_
{
false
};
};
// Configurations for Anakin engine.
bool
enable_ir_optim_
{
true
};
struct
AnakinConfig
:
public
PaddlePredictor
::
Config
{
bool
use_feed_fetch_ops_
{
true
};
enum
TargetType
{
NVGPU
=
0
,
X86
};
bool
ir_debug_
{
false
};
int
device
;
std
::
string
model_file
;
bool
specify_input_name_
{
false
};
int
max_batch_size
{
-
1
};
TargetType
target_type
;
int
cpu_math_library_num_threads_
{
1
};
// A runtime cache, shouldn't be transferred to others.
std
::
string
serialized_info_cache_
;
mutable
std
::
unique_ptr
<
PassStrategy
>
pass_builder_
;
};
};
}
// namespace contrib
}
// namespace contrib
...
...
paddle/fluid/inference/api/paddle_inference_api.h
浏览文件 @
68a07328
...
@@ -26,9 +26,8 @@ limitations under the License. */
...
@@ -26,9 +26,8 @@ limitations under the License. */
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle_api.h" // NOLINT
#ifndef WITH_ANAKIN
#include "paddle_analysis_config.h" // NOLINT
#include "paddle_analysis_config.h" // NOLINT
#else
#include "paddle_api.h" // NOLINT
#ifdef WITH_ANAKIN
#include "paddle_anakin_config.h" // NOLINT
#include "paddle_anakin_config.h" // NOLINT
#endif
#endif
paddle/fluid/inference/api/paddle_pass_builder.h
浏览文件 @
68a07328
...
@@ -62,7 +62,12 @@ class PassStrategy : public PaddlePassBuilder {
...
@@ -62,7 +62,12 @@ class PassStrategy : public PaddlePassBuilder {
// still some CPU kernels running in CPU mode.
// still some CPU kernels running in CPU mode.
virtual
void
EnableMKLDNN
()
=
0
;
virtual
void
EnableMKLDNN
()
=
0
;
bool
use_gpu
()
const
{
return
use_gpu_
;
}
virtual
~
PassStrategy
()
=
default
;
virtual
~
PassStrategy
()
=
default
;
protected:
bool
use_gpu_
{
false
};
};
};
/*
/*
...
@@ -88,6 +93,7 @@ class CpuPassStrategy : public PassStrategy {
...
@@ -88,6 +93,7 @@ class CpuPassStrategy : public PassStrategy {
"conv_eltwiseadd_bn_fuse_pass"
,
//
"conv_eltwiseadd_bn_fuse_pass"
,
//
"is_test_pass"
,
//
"is_test_pass"
,
//
});
});
use_gpu_
=
false
;
}
}
virtual
~
CpuPassStrategy
()
=
default
;
virtual
~
CpuPassStrategy
()
=
default
;
...
@@ -126,10 +132,14 @@ class GpuPassStrategy : public PassStrategy {
...
@@ -126,10 +132,14 @@ class GpuPassStrategy : public PassStrategy {
"conv_elementwise_add2_act_fuse_pass"
,
//
"conv_elementwise_add2_act_fuse_pass"
,
//
"conv_elementwise_add_fuse_pass"
,
//
"conv_elementwise_add_fuse_pass"
,
//
});
});
use_gpu_
=
true
;
}
}
GpuPassStrategy
(
const
GpuPassStrategy
&
other
)
GpuPassStrategy
(
const
GpuPassStrategy
&
other
)
:
PassStrategy
(
other
.
AllPasses
())
{}
:
PassStrategy
(
other
.
AllPasses
())
{
use_gpu_
=
true
;
}
void
EnableMKLDNN
()
override
;
void
EnableMKLDNN
()
override
;
...
...
paddle/fluid/inference/tensorrt/CMakeLists.txt
浏览文件 @
68a07328
nv_library
(
tensorrt_engine SRCS engine.cc DEPS
${
GLOB_OPERATOR_DEPS
}
framework_proto device_context
)
nv_library
(
tensorrt_engine SRCS engine.cc DEPS
${
GLOB_OPERATOR_DEPS
}
framework_proto device_context
)
nv_library
(
tensorrt_op_teller SRCS op_teller.cc DEPS framework_proto
)
nv_test
(
test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader
)
nv_test
(
test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader
)
nv_test
(
test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine
)
nv_test
(
test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine
)
add_subdirectory
(
plugin
)
add_subdirectory
(
plugin
)
...
...
paddle/fluid/inference/tensorrt/op_teller.cc
0 → 100644
浏览文件 @
68a07328
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/op_teller.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
// Just tell by the op_types.
struct
SimpleOpTypeSetTeller
:
public
Teller
{
SimpleOpTypeSetTeller
()
{}
bool
operator
()(
const
std
::
string
&
op_type
,
const
framework
::
OpDesc
&
desc
)
override
{
return
teller_set
.
count
(
op_type
);
}
private:
std
::
unordered_set
<
std
::
string
>
teller_set
{
{
"mul"
,
"conv2d"
,
"pool2d"
,
"relu"
,
"softmax"
,
"sigmoid"
,
"depthwise_conv2d"
,
"batch_norm"
,
"concat"
,
"tanh"
,
"pad"
,
"elementwise_add"
,
"elementwise_mul"
,
"dropout"
,
"split"
,
"prelu"
,
"conv2d_transpose"
,
"leaky_relu"
}};
};
bool
OpTeller
::
Tell
(
const
std
::
string
&
op_type
,
const
framework
::
OpDesc
&
desc
)
{
for
(
auto
&
teller
:
tellers_
)
{
if
((
*
teller
)(
op_type
,
desc
))
return
true
;
}
return
false
;
}
OpTeller
::
OpTeller
()
{
tellers_
.
emplace_back
(
new
SimpleOpTypeSetTeller
);
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/op_teller.h
0 → 100644
浏览文件 @
68a07328
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
/*
* Single Op teller definition.
* One can override this and define a more complex tell logic, considerring more
* issues such as op_desc.
*/
struct
Teller
{
virtual
bool
operator
()(
const
std
::
string
&
op_type
,
const
framework
::
OpDesc
&
desc
)
=
0
;
virtual
~
Teller
()
=
default
;
};
/*
* A real example:
*
* struct SomeTeller : public Teller {
* bool operator()(const std::string& op_type,
* const framework::OpDesc& desc) override {
* return op_type == "fc" && desc.Inputs().size() == 2;
* }
*};
*/
/*
* class OpTeller helps to tell whether a fluid
* operator can be transformed to a TensorRT layer.
*/
class
OpTeller
{
public:
static
OpTeller
&
Global
()
{
static
std
::
unique_ptr
<
OpTeller
>
x
(
new
OpTeller
);
return
*
x
;
}
bool
Tell
(
const
std
::
string
&
op_type
,
const
framework
::
OpDesc
&
desc
);
private:
OpTeller
();
private:
std
::
vector
<
std
::
unique_ptr
<
Teller
>>
tellers_
;
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
浏览文件 @
68a07328
...
@@ -165,12 +165,9 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
...
@@ -165,12 +165,9 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
}
}
void
SetConfig
(
contrib
::
AnalysisConfig
*
cfg
)
{
void
SetConfig
(
contrib
::
AnalysisConfig
*
cfg
)
{
cfg
->
prog_file
=
FLAGS_infer_model
+
"/__model__"
;
cfg
->
SetModel
(
FLAGS_infer_model
+
"/__model__"
,
FLAGS_infer_model
+
"/param"
);
cfg
->
param_file
=
FLAGS_infer_model
+
"/param"
;
cfg
->
SwitchSpecifyInputNames
();
cfg
->
use_gpu
=
false
;
cfg
->
SwitchIrOptim
(
true
);
cfg
->
device
=
0
;
cfg
->
specify_input_name
=
true
;
cfg
->
enable_ir_optim
=
true
;
}
}
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
...
...
paddle/fluid/inference/tests/api/analyzer_lac_tester.cc
浏览文件 @
68a07328
...
@@ -105,11 +105,10 @@ void GetOneBatch(std::vector<PaddleTensor> *input_slots, DataRecord *data,
...
@@ -105,11 +105,10 @@ void GetOneBatch(std::vector<PaddleTensor> *input_slots, DataRecord *data,
}
}
void
SetConfig
(
AnalysisConfig
*
cfg
)
{
void
SetConfig
(
AnalysisConfig
*
cfg
)
{
cfg
->
model_dir
=
FLAGS_infer_model
;
cfg
->
SetModel
(
FLAGS_infer_model
);
cfg
->
use_gpu
=
false
;
cfg
->
DisableGpu
();
cfg
->
device
=
0
;
cfg
->
SwitchSpecifyInputNames
();
cfg
->
specify_input_name
=
true
;
cfg
->
SwitchIrOptim
();
cfg
->
enable_ir_optim
=
true
;
}
}
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
...
...
paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc
浏览文件 @
68a07328
...
@@ -76,11 +76,10 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
...
@@ -76,11 +76,10 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
}
}
void
SetConfig
(
contrib
::
AnalysisConfig
*
cfg
)
{
void
SetConfig
(
contrib
::
AnalysisConfig
*
cfg
)
{
cfg
->
model_dir
=
FLAGS_infer_model
;
cfg
->
SetModel
(
FLAGS_infer_model
);
cfg
->
use_gpu
=
false
;
cfg
->
DisableGpu
();
cfg
->
device
=
0
;
cfg
->
SwitchSpecifyInputNames
();
cfg
->
specify_input_name
=
true
;
cfg
->
SwitchIrOptim
();
cfg
->
enable_ir_optim
=
true
;
}
}
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
...
...
paddle/fluid/inference/tests/api/analyzer_ner_tester.cc
浏览文件 @
68a07328
...
@@ -84,13 +84,12 @@ void SetConfig(contrib::AnalysisConfig *cfg, bool memory_load = false) {
...
@@ -84,13 +84,12 @@ void SetConfig(contrib::AnalysisConfig *cfg, bool memory_load = false) {
cfg
->
SetModelBuffer
(
&
buffer_prog
[
0
],
buffer_prog
.
size
(),
&
buffer_param
[
0
],
cfg
->
SetModelBuffer
(
&
buffer_prog
[
0
],
buffer_prog
.
size
(),
&
buffer_param
[
0
],
buffer_param
.
size
());
buffer_param
.
size
());
}
else
{
}
else
{
cfg
->
prog_file
=
FLAGS_infer_model
+
"/__model__"
;
cfg
->
SetModel
(
FLAGS_infer_model
+
"/__model__"
,
cfg
->
param_file
=
FLAGS_infer_model
+
"/param"
;
FLAGS_infer_model
+
"/param"
)
;
}
}
cfg
->
use_gpu
=
false
;
cfg
->
DisableGpu
();
cfg
->
device
=
0
;
cfg
->
SwitchSpecifyInputNames
();
cfg
->
specify_input_name
=
true
;
cfg
->
SwitchIrOptim
();
cfg
->
enable_ir_optim
=
true
;
}
}
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
...
...
paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc
浏览文件 @
68a07328
...
@@ -21,12 +21,10 @@ namespace inference {
...
@@ -21,12 +21,10 @@ namespace inference {
namespace
analysis
{
namespace
analysis
{
void
SetConfig
(
AnalysisConfig
*
cfg
)
{
void
SetConfig
(
AnalysisConfig
*
cfg
)
{
cfg
->
param_file
=
FLAGS_infer_model
+
"/params"
;
cfg
->
SetModel
(
FLAGS_infer_model
+
"/model"
,
FLAGS_infer_model
+
"/params"
);
cfg
->
prog_file
=
FLAGS_infer_model
+
"/model"
;
cfg
->
DisableGpu
();
cfg
->
use_gpu
=
false
;
cfg
->
SwitchIrOptim
();
cfg
->
device
=
0
;
cfg
->
SwitchSpecifyInputNames
();
cfg
->
enable_ir_optim
=
true
;
cfg
->
specify_input_name
=
true
;
cfg
->
SetCpuMathLibraryNumThreads
(
FLAGS_paddle_num_threads
);
cfg
->
SetCpuMathLibraryNumThreads
(
FLAGS_paddle_num_threads
);
}
}
...
...
paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc
浏览文件 @
68a07328
...
@@ -204,12 +204,10 @@ void PrepareZeroCopyInputs(ZeroCopyTensor *lod_attention_tensor,
...
@@ -204,12 +204,10 @@ void PrepareZeroCopyInputs(ZeroCopyTensor *lod_attention_tensor,
}
}
void
SetConfig
(
AnalysisConfig
*
cfg
)
{
void
SetConfig
(
AnalysisConfig
*
cfg
)
{
cfg
->
prog_file
=
FLAGS_infer_model
+
"/__model__"
;
cfg
->
SetModel
(
FLAGS_infer_model
+
"/__model__"
,
FLAGS_infer_model
+
"/param"
);
cfg
->
param_file
=
FLAGS_infer_model
+
"/param"
;
cfg
->
DisableGpu
();
cfg
->
use_gpu
=
false
;
cfg
->
SwitchSpecifyInputNames
();
cfg
->
device
=
0
;
cfg
->
SwitchIrOptim
();
cfg
->
specify_input_name
=
true
;
cfg
->
enable_ir_optim
=
true
;
}
}
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
...
@@ -225,10 +223,10 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
...
@@ -225,10 +223,10 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
// Easy for profiling independently.
// Easy for profiling independently.
TEST
(
Analyzer_rnn1
,
profile
)
{
TEST
(
Analyzer_rnn1
,
profile
)
{
contrib
::
AnalysisConfig
cfg
(
false
)
;
contrib
::
AnalysisConfig
cfg
;
SetConfig
(
&
cfg
);
SetConfig
(
&
cfg
);
cfg
.
fraction_of_gpu_memory
=
0.1
;
cfg
.
DisableGpu
()
;
cfg
.
pass_builder
()
->
TurnOn
Debug
();
cfg
.
SwitchIr
Debug
();
std
::
vector
<
PaddleTensor
>
outputs
;
std
::
vector
<
PaddleTensor
>
outputs
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
...
@@ -293,16 +291,18 @@ TEST(Analyzer_rnn1, multi_thread) {
...
@@ -293,16 +291,18 @@ TEST(Analyzer_rnn1, multi_thread) {
TEST
(
Analyzer_rnn1
,
ZeroCopy
)
{
TEST
(
Analyzer_rnn1
,
ZeroCopy
)
{
AnalysisConfig
config
;
AnalysisConfig
config
;
SetConfig
(
&
config
);
SetConfig
(
&
config
);
config
.
use_feed_fetch_ops
=
false
;
config
.
SwitchUseFeedFetchOps
(
false
)
;
PaddlePlace
place
;
PaddlePlace
place
;
auto
predictor
=
CreatePaddlePredictor
<
AnalysisConfig
>
(
config
);
auto
predictor
=
CreatePaddlePredictor
<
AnalysisConfig
>
(
config
);
config
.
use_feed_fetch_ops
=
true
;
config
.
SwitchUseFeedFetchOps
(
true
);
auto
native_predictor
=
CreatePaddlePredictor
<
NativeConfig
>
(
config
);
auto
native_predictor
=
CreatePaddlePredictor
<
NativeConfig
>
(
config
.
ToNativeConfig
());
config
.
use_feed_fetch_ops
=
true
;
// the analysis predictor needs feed/fetch.
config
.
SwitchUseFeedFetchOps
(
true
);
// the analysis predictor needs feed/fetch.
auto
analysis_predictor
=
CreatePaddlePredictor
<
AnalysisConfig
>
(
config
);
auto
analysis_predictor
=
CreatePaddlePredictor
<
AnalysisConfig
>
(
config
);
#define NEW_TENSOR(name__) \
#define NEW_TENSOR(name__) \
...
@@ -362,7 +362,7 @@ TEST(Analyzer_rnn1, ZeroCopy) {
...
@@ -362,7 +362,7 @@ TEST(Analyzer_rnn1, ZeroCopy) {
TEST
(
Analyzer_rnn1
,
ZeroCopyMultiThread
)
{
TEST
(
Analyzer_rnn1
,
ZeroCopyMultiThread
)
{
AnalysisConfig
config
;
AnalysisConfig
config
;
SetConfig
(
&
config
);
SetConfig
(
&
config
);
config
.
use_feed_fetch_ops
=
false
;
config
.
SwitchUseFeedFetchOps
(
false
)
;
#define NEW_TENSOR(name__) \
#define NEW_TENSOR(name__) \
auto name__##_tensor = predictor->GetInputTensor(#name__);
auto name__##_tensor = predictor->GetInputTensor(#name__);
...
...
paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc
浏览文件 @
68a07328
...
@@ -105,12 +105,10 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
...
@@ -105,12 +105,10 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
}
}
void
SetConfig
(
AnalysisConfig
*
cfg
)
{
void
SetConfig
(
AnalysisConfig
*
cfg
)
{
cfg
->
prog_file
=
FLAGS_infer_model
+
"/__model__"
;
cfg
->
SetModel
(
FLAGS_infer_model
+
"/__model__"
,
FLAGS_infer_model
+
"/param"
);
cfg
->
param_file
=
FLAGS_infer_model
+
"/param"
;
cfg
->
DisableGpu
();
cfg
->
use_gpu
=
false
;
cfg
->
SwitchSpecifyInputNames
();
cfg
->
device
=
0
;
cfg
->
SwitchIrOptim
();
cfg
->
specify_input_name
=
true
;
cfg
->
enable_ir_optim
=
true
;
}
}
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
...
...
paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc
浏览文件 @
68a07328
...
@@ -89,11 +89,10 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
...
@@ -89,11 +89,10 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
}
}
void
SetConfig
(
AnalysisConfig
*
cfg
)
{
void
SetConfig
(
AnalysisConfig
*
cfg
)
{
cfg
->
model_dir
=
FLAGS_infer_model
;
cfg
->
SetModel
(
FLAGS_infer_model
);
cfg
->
use_gpu
=
false
;
cfg
->
DisableGpu
();
cfg
->
device
=
0
;
cfg
->
SwitchSpecifyInputNames
();
cfg
->
specify_input_name
=
true
;
cfg
->
SwitchIrOptim
();
cfg
->
enable_ir_optim
=
true
;
}
}
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
...
...
paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc
浏览文件 @
68a07328
...
@@ -122,12 +122,9 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data) {
...
@@ -122,12 +122,9 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data) {
}
}
void
SetConfig
(
AnalysisConfig
*
cfg
)
{
void
SetConfig
(
AnalysisConfig
*
cfg
)
{
cfg
->
param_file
=
FLAGS_infer_model
+
"/params"
;
cfg
->
SetModel
(
FLAGS_infer_model
+
"/model"
,
FLAGS_infer_model
+
"/params"
);
cfg
->
prog_file
=
FLAGS_infer_model
+
"/model"
;
cfg
->
DisableGpu
();
cfg
->
use_gpu
=
false
;
cfg
->
SwitchSpecifyInputNames
();
cfg
->
device
=
0
;
cfg
->
enable_ir_optim
=
true
;
cfg
->
specify_input_name
=
true
;
cfg
->
pass_builder
()
->
TurnOnDebug
();
cfg
->
pass_builder
()
->
TurnOnDebug
();
cfg
->
SetCpuMathLibraryNumThreads
(
FLAGS_paddle_num_threads
);
cfg
->
SetCpuMathLibraryNumThreads
(
FLAGS_paddle_num_threads
);
}
}
...
...
paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc
浏览文件 @
68a07328
...
@@ -47,11 +47,10 @@ struct DataReader {
...
@@ -47,11 +47,10 @@ struct DataReader {
};
};
void
SetConfig
(
AnalysisConfig
*
cfg
)
{
void
SetConfig
(
AnalysisConfig
*
cfg
)
{
cfg
->
model_dir
=
FLAGS_infer_model
;
cfg
->
SetModel
(
FLAGS_infer_model
);
cfg
->
use_gpu
=
false
;
cfg
->
DisableGpu
();
cfg
->
device
=
0
;
cfg
->
SwitchSpecifyInputNames
();
cfg
->
specify_input_name
=
true
;
cfg
->
SwitchIrOptim
();
cfg
->
enable_ir_optim
=
true
;
}
}
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
...
...
paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
浏览文件 @
68a07328
...
@@ -51,12 +51,11 @@ Record ProcessALine(const std::string &line) {
...
@@ -51,12 +51,11 @@ Record ProcessALine(const std::string &line) {
}
}
void
SetConfig
(
AnalysisConfig
*
cfg
)
{
void
SetConfig
(
AnalysisConfig
*
cfg
)
{
cfg
->
param_file
=
FLAGS_infer_model
+
"/__params__"
;
cfg
->
SetModel
(
FLAGS_infer_model
+
"/__model__"
,
cfg
->
prog_file
=
FLAGS_infer_model
+
"/__model__"
;
FLAGS_infer_model
+
"/__params__"
);
cfg
->
use_gpu
=
false
;
cfg
->
DisableGpu
();
cfg
->
device
=
0
;
cfg
->
SwitchIrDebug
();
cfg
->
enable_ir_optim
=
true
;
cfg
->
SwitchSpecifyInputNames
();
cfg
->
specify_input_name
=
true
;
// TODO(TJ): fix fusion gru
// TODO(TJ): fix fusion gru
cfg
->
pass_builder
()
->
DeletePass
(
"fc_gru_fuse_pass"
);
cfg
->
pass_builder
()
->
DeletePass
(
"fc_gru_fuse_pass"
);
}
}
...
...
paddle/fluid/inference/tests/api/config_printer.h
浏览文件 @
68a07328
...
@@ -64,19 +64,23 @@ std::ostream &operator<<(std::ostream &os,
...
@@ -64,19 +64,23 @@ std::ostream &operator<<(std::ostream &os,
num_spaces
++
;
num_spaces
++
;
os
<<
*
reinterpret_cast
<
const
NativeConfig
*>
(
&
config
);
os
<<
*
reinterpret_cast
<
const
NativeConfig
*>
(
&
config
);
if
(
!
config
.
model_from_memory
())
{
if
(
!
config
.
model_from_memory
())
{
os
<<
GenSpaces
(
num_spaces
)
<<
"prog_file: "
<<
config
.
prog_file
<<
"
\n
"
;
os
<<
GenSpaces
(
num_spaces
)
<<
"prog_file: "
<<
config
.
prog_file
()
<<
"
\n
"
;
os
<<
GenSpaces
(
num_spaces
)
<<
"param_file: "
<<
config
.
param_file
<<
"
\n
"
;
os
<<
GenSpaces
(
num_spaces
)
<<
"param_file: "
<<
config
.
params_file
()
<<
"
\n
"
;
}
else
{
}
else
{
os
<<
GenSpaces
(
num_spaces
)
os
<<
GenSpaces
(
num_spaces
)
<<
"prog_file and param_file: load from memory
\n
"
;
<<
"prog_file and param_file: load from memory
\n
"
;
}
}
os
<<
GenSpaces
(
num_spaces
)
<<
"enable_ir_optim: "
<<
config
.
enable_ir_optim
os
<<
GenSpaces
(
num_spaces
)
<<
"enable_ir_optim: "
<<
config
.
ir_optim
()
<<
"
\n
"
;
<<
"
\n
"
;
os
<<
GenSpaces
(
num_spaces
)
<<
"enable_ir_optim: "
<<
config
.
ir_optim
()
<<
"
\n
"
;
os
<<
GenSpaces
(
num_spaces
)
<<
"use_feed_fetch_ops: "
<<
config
.
use_feed_fetch_ops_enabled
()
<<
"
\n
"
;
os
<<
GenSpaces
(
num_spaces
)
os
<<
GenSpaces
(
num_spaces
)
<<
"use_
feed_fetch_ops: "
<<
config
.
use_feed_fetch_ops
<<
"
\n
"
;
<<
"use_
tensorrt: "
<<
config
.
tensorrt_engine_enabled
()
<<
"
\n
"
;
os
<<
GenSpaces
(
num_spaces
)
<<
"use_
tensorrt: "
<<
config
.
use_tensorrt
()
os
<<
GenSpaces
(
num_spaces
)
<<
"use_
mkldnn: "
<<
config
.
mkldnn_enabled
()
<<
"
\n
"
;
<<
"
\n
"
;
os
<<
GenSpaces
(
num_spaces
)
<<
"use_mkldnn: "
<<
config
.
use_mkldnn
()
<<
"
\n
"
;
num_spaces
--
;
num_spaces
--
;
os
<<
GenSpaces
(
num_spaces
)
<<
"}
\n
"
;
os
<<
GenSpaces
(
num_spaces
)
<<
"}
\n
"
;
return
os
;
return
os
;
...
...
paddle/fluid/inference/tests/api/tester_helper.h
浏览文件 @
68a07328
...
@@ -328,7 +328,10 @@ void CompareNativeAndAnalysis(
...
@@ -328,7 +328,10 @@ void CompareNativeAndAnalysis(
const
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
&
inputs
)
{
const
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
&
inputs
)
{
PrintConfig
(
config
,
true
);
PrintConfig
(
config
,
true
);
std
::
vector
<
PaddleTensor
>
native_outputs
,
analysis_outputs
;
std
::
vector
<
PaddleTensor
>
native_outputs
,
analysis_outputs
;
TestOneThreadPrediction
(
config
,
inputs
,
&
native_outputs
,
false
);
const
auto
*
analysis_config
=
reinterpret_cast
<
const
contrib
::
AnalysisConfig
*>
(
config
);
auto
native_config
=
analysis_config
->
ToNativeConfig
();
TestOneThreadPrediction
(
&
native_config
,
inputs
,
&
native_outputs
,
false
);
TestOneThreadPrediction
(
config
,
inputs
,
&
analysis_outputs
,
true
);
TestOneThreadPrediction
(
config
,
inputs
,
&
analysis_outputs
,
true
);
CompareResult
(
analysis_outputs
,
native_outputs
);
CompareResult
(
analysis_outputs
,
native_outputs
);
}
}
...
...
paddle/fluid/inference/tests/api/trt_models_tester.cc
浏览文件 @
68a07328
...
@@ -46,22 +46,20 @@ void SetConfig<contrib::AnalysisConfig>(contrib::AnalysisConfig* config,
...
@@ -46,22 +46,20 @@ void SetConfig<contrib::AnalysisConfig>(contrib::AnalysisConfig* config,
std
::
string
model_dir
,
bool
use_gpu
,
std
::
string
model_dir
,
bool
use_gpu
,
bool
use_tensorrt
,
int
batch_size
)
{
bool
use_tensorrt
,
int
batch_size
)
{
if
(
!
FLAGS_prog_filename
.
empty
()
&&
!
FLAGS_param_filename
.
empty
())
{
if
(
!
FLAGS_prog_filename
.
empty
()
&&
!
FLAGS_param_filename
.
empty
())
{
config
->
prog_file
=
model_dir
+
"/"
+
FLAGS_prog_filename
;
config
->
SetModel
(
model_dir
+
"/"
+
FLAGS_prog_filename
,
config
->
param_file
=
model_dir
+
"/"
+
FLAGS_param_filename
;
model_dir
+
"/"
+
FLAGS_param_filename
)
;
}
else
{
}
else
{
config
->
model_dir
=
model_dir
;
config
->
SetModel
(
model_dir
)
;
}
}
if
(
use_gpu
)
{
if
(
use_gpu
)
{
config
->
use_gpu
=
true
;
config
->
EnableUseGpu
(
100
,
0
);
config
->
device
=
0
;
config
->
fraction_of_gpu_memory
=
0.15
;
if
(
use_tensorrt
)
{
if
(
use_tensorrt
)
{
config
->
EnableTensorRtEngine
(
1
<<
10
,
batch_size
);
config
->
EnableTensorRtEngine
(
1
<<
10
,
batch_size
);
config
->
pass_builder
()
->
DeletePass
(
"conv_bn_fuse_pass"
);
config
->
pass_builder
()
->
DeletePass
(
"conv_bn_fuse_pass"
);
config
->
pass_builder
()
->
DeletePass
(
"fc_fuse_pass"
);
config
->
pass_builder
()
->
DeletePass
(
"fc_fuse_pass"
);
config
->
pass_builder
()
->
TurnOnDebug
();
config
->
pass_builder
()
->
TurnOnDebug
();
}
else
{
}
else
{
config
->
enable_ir_optim
=
true
;
config
->
SwitchIrOptim
()
;
}
}
}
}
}
}
...
@@ -77,7 +75,8 @@ void profile(std::string model_dir, bool use_analysis, bool use_tensorrt) {
...
@@ -77,7 +75,8 @@ void profile(std::string model_dir, bool use_analysis, bool use_tensorrt) {
std
::
vector
<
PaddleTensor
>
outputs
;
std
::
vector
<
PaddleTensor
>
outputs
;
if
(
use_analysis
||
use_tensorrt
)
{
if
(
use_analysis
||
use_tensorrt
)
{
contrib
::
AnalysisConfig
config
(
true
);
contrib
::
AnalysisConfig
config
;
config
.
EnableUseGpu
(
100
,
0
);
config
.
pass_builder
()
->
TurnOnDebug
();
config
.
pass_builder
()
->
TurnOnDebug
();
SetConfig
<
contrib
::
AnalysisConfig
>
(
&
config
,
model_dir
,
true
,
use_tensorrt
,
SetConfig
<
contrib
::
AnalysisConfig
>
(
&
config
,
model_dir
,
true
,
use_tensorrt
,
FLAGS_batch_size
);
FLAGS_batch_size
);
...
@@ -109,7 +108,8 @@ void compare(std::string model_dir, bool use_tensorrt) {
...
@@ -109,7 +108,8 @@ void compare(std::string model_dir, bool use_tensorrt) {
&
native_outputs
,
false
);
&
native_outputs
,
false
);
std
::
vector
<
PaddleTensor
>
analysis_outputs
;
std
::
vector
<
PaddleTensor
>
analysis_outputs
;
contrib
::
AnalysisConfig
analysis_config
(
true
);
contrib
::
AnalysisConfig
analysis_config
;
analysis_config
.
EnableUseGpu
(
50
,
0
);
SetConfig
<
contrib
::
AnalysisConfig
>
(
&
analysis_config
,
model_dir
,
true
,
SetConfig
<
contrib
::
AnalysisConfig
>
(
&
analysis_config
,
model_dir
,
true
,
use_tensorrt
,
FLAGS_batch_size
);
use_tensorrt
,
FLAGS_batch_size
);
TestOneThreadPrediction
(
TestOneThreadPrediction
(
...
@@ -154,9 +154,9 @@ TEST(TensorRT_mobilenet, analysis) {
...
@@ -154,9 +154,9 @@ TEST(TensorRT_mobilenet, analysis) {
TEST
(
AnalysisPredictor
,
use_gpu
)
{
TEST
(
AnalysisPredictor
,
use_gpu
)
{
std
::
string
model_dir
=
FLAGS_infer_model
+
"/"
+
"mobilenet"
;
std
::
string
model_dir
=
FLAGS_infer_model
+
"/"
+
"mobilenet"
;
AnalysisConfig
config
(
true
)
;
AnalysisConfig
config
;
config
.
model_dir
=
model_dir
;
config
.
EnableUseGpu
(
100
,
0
)
;
config
.
fraction_of_gpu_memory
=
0.15
;
config
.
SetModel
(
model_dir
)
;
config
.
pass_builder
()
->
TurnOnDebug
();
config
.
pass_builder
()
->
TurnOnDebug
();
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
inputs_all
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
inputs_all
;
...
...
paddle/fluid/operators/conv_mkldnn_op.cc
浏览文件 @
68a07328
...
@@ -319,6 +319,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -319,6 +319,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std
::
vector
<
int
>
dilations
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"dilations"
);
std
::
vector
<
int
>
dilations
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"dilations"
);
int
groups
=
ctx
.
Attr
<
int
>
(
"groups"
);
int
groups
=
ctx
.
Attr
<
int
>
(
"groups"
);
bool
fuse_relu
=
ctx
.
Attr
<
bool
>
(
"fuse_relu"
);
bool
force_fp32_output
=
ctx
.
Attr
<
bool
>
(
"force_fp32_output"
);
bool
force_fp32_output
=
ctx
.
Attr
<
bool
>
(
"force_fp32_output"
);
bool
is_conv3d
=
strides
.
size
()
==
3U
;
bool
is_conv3d
=
strides
.
size
()
==
3U
;
...
@@ -329,6 +331,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -329,6 +331,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dilations
[
2
]
==
1
dilations
[
2
]
==
1
:
dilations
.
size
()
==
2
&&
dilations
[
0
]
==
1
&&
dilations
[
1
]
==
1
,
:
dilations
.
size
()
==
2
&&
dilations
[
0
]
==
1
&&
dilations
[
1
]
==
1
,
"dilation in convolution is not implemented yet"
);
"dilation in convolution is not implemented yet"
);
PADDLE_ENFORCE
(
is_conv3d
!=
true
,
"int8 does not support conv3d currently"
);
PADDLE_ENFORCE
(
is_conv3d
!=
true
,
"int8 does not support conv3d currently"
);
const
T
*
input_data
=
input
->
data
<
T
>
();
const
T
*
input_data
=
input
->
data
<
T
>
();
...
@@ -340,15 +343,24 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -340,15 +343,24 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
GetWeightsTz
(
weights_tz
,
g
,
is_conv3d
);
GetWeightsTz
(
weights_tz
,
g
,
is_conv3d
);
std
::
vector
<
int
>
dst_tz
=
paddle
::
framework
::
vectorize2int
(
output
->
dims
());
std
::
vector
<
int
>
dst_tz
=
paddle
::
framework
::
vectorize2int
(
output
->
dims
());
mkldnn
::
memory
::
data_type
src_dt
=
paddle
::
framework
::
ToMKLDNNDataType
(
input
->
type
());
auto
dst_dt
=
fuse_relu
?
paddle
::
framework
::
ToMKLDNNDataType
(
framework
::
DataTypeTrait
<
uint8_t
>::
DataType
)
:
paddle
::
framework
::
ToMKLDNNDataType
(
framework
::
DataTypeTrait
<
int8_t
>::
DataType
);
if
(
force_fp32_output
)
{
dst_dt
=
paddle
::
framework
::
ToMKLDNNDataType
(
framework
::
DataTypeTrait
<
float
>::
DataType
);
}
// Get unique name for storing MKLDNN primitives
// Get unique name for storing MKLDNN primitives
std
::
string
key
;
std
::
string
key
;
key
.
reserve
(
MaxKeyLength
);
key
.
reserve
(
MaxKeyLength
);
mkldnn
::
memory
::
data_type
src_dt
=
paddle
::
framework
::
ToMKLDNNDataType
(
input
->
type
());
platform
::
ConvMKLDNNHandler
::
AppendKey
(
platform
::
ConvMKLDNNHandler
::
AppendKey
(
&
key
,
src_tz
,
weights_tz
,
strides
,
paddings
,
dilations
,
groups
,
src_dt
,
&
key
,
src_tz
,
weights_tz
,
strides
,
paddings
,
dilations
,
groups
,
src_dt
,
input
->
format
(),
ctx
.
op
().
Output
(
"Output"
));
input
->
format
(),
dst_dt
,
ctx
.
op
().
Output
(
"Output"
));
const
std
::
string
key_conv_pd
=
key
+
"@conv_pd"
;
const
std
::
string
key_conv_pd
=
key
+
"@conv_pd"
;
std
::
shared_ptr
<
mkldnn
::
convolution_forward
>
conv_p
=
nullptr
;
std
::
shared_ptr
<
mkldnn
::
convolution_forward
>
conv_p
=
nullptr
;
...
@@ -413,13 +425,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -413,13 +425,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
platform
::
MKLDNNMemDesc
(
src_tz
,
src_dt
,
chosen_memory_format
);
platform
::
MKLDNNMemDesc
(
src_tz
,
src_dt
,
chosen_memory_format
);
auto
weights_md
=
platform
::
MKLDNNMemDesc
(
auto
weights_md
=
platform
::
MKLDNNMemDesc
(
weights_tz
,
memory
::
data_type
::
s8
,
chosen_memory_format
);
weights_tz
,
memory
::
data_type
::
s8
,
chosen_memory_format
);
auto
dst_dt
=
force_fp32_output
?
paddle
::
framework
::
ToMKLDNNDataType
(
framework
::
DataTypeTrait
<
float
>::
DataType
)
:
paddle
::
framework
::
ToMKLDNNDataType
(
framework
::
DataTypeTrait
<
int8_t
>::
DataType
);
auto
dst_md
=
auto
dst_md
=
platform
::
MKLDNNMemDesc
(
dst_tz
,
dst_dt
,
chosen_memory_format
);
platform
::
MKLDNNMemDesc
(
dst_tz
,
dst_dt
,
chosen_memory_format
);
// create a conv primitive descriptor and save it for usage in backward
// create a conv primitive descriptor and save it for usage in backward
...
@@ -429,11 +434,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -429,11 +434,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
memory
::
format
::
x
);
memory
::
format
::
x
);
conv_pd
=
ConvFwdPrimitiveDesc
(
src_md
,
weights_md
,
bias_md
,
dst_md
,
conv_pd
=
ConvFwdPrimitiveDesc
(
src_md
,
weights_md
,
bias_md
,
dst_md
,
strides
,
paddings
,
mkldnn_engine
,
strides
,
paddings
,
mkldnn_engine
,
output_shift_scale
,
is_test
);
fuse_relu
,
output_shift_scale
,
is_test
);
}
else
{
}
else
{
conv_pd
=
conv_pd
=
ConvFwdPrimitiveDesc
(
src_md
,
weights_md
,
dst_md
,
strides
,
ConvFwdPrimitiveDesc
(
src_md
,
weights_md
,
dst_md
,
strides
,
paddings
,
paddings
,
mkldnn_engine
,
fuse_relu
,
mkldnn_engine
,
output_shift_scale
,
is_test
);
output_shift_scale
,
is_test
);
}
}
// Save conv_pd/src_memory/weights_memory for backward pass
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx
.
SetBlob
(
key_conv_pd
,
conv_pd
);
dev_ctx
.
SetBlob
(
key_conv_pd
,
conv_pd
);
...
@@ -459,7 +464,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -459,7 +464,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mask_reorder
);
mask_reorder
);
if
(
!
force_fp32_output
)
{
if
(
!
force_fp32_output
)
{
dst_memory_p
=
platform
::
SetDstMemory
<
int8_t
>
(
ctx
,
output
,
handler
);
if
(
fuse_relu
)
{
dst_memory_p
=
platform
::
SetDstMemory
<
uint8_t
>
(
ctx
,
output
,
handler
);
}
else
{
dst_memory_p
=
platform
::
SetDstMemory
<
int8_t
>
(
ctx
,
output
,
handler
);
}
}
else
{
}
else
{
dst_memory_p
=
platform
::
SetDstMemory
<
float
>
(
ctx
,
output
,
handler
);
dst_memory_p
=
platform
::
SetDstMemory
<
float
>
(
ctx
,
output
,
handler
);
}
}
...
@@ -518,8 +527,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -518,8 +527,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn_engine
,
key
));
mkldnn_engine
,
key
));
}
}
if
(
!
force_fp32_output
)
{
if
(
!
force_fp32_output
)
{
dst_memory_p
=
if
(
fuse_relu
)
{
platform
::
SetDstMemoryHandler
<
int8_t
>
(
ctx
,
output
,
handler
);
dst_memory_p
=
platform
::
SetDstMemoryHandler
<
uint8_t
>
(
ctx
,
output
,
handler
);
}
else
{
dst_memory_p
=
platform
::
SetDstMemoryHandler
<
int8_t
>
(
ctx
,
output
,
handler
);
}
}
else
{
}
else
{
dst_memory_p
=
dst_memory_p
=
platform
::
SetDstMemoryHandler
<
float
>
(
ctx
,
output
,
handler
);
platform
::
SetDstMemoryHandler
<
float
>
(
ctx
,
output
,
handler
);
...
@@ -563,11 +577,18 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -563,11 +577,18 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
}
mkldnn
::
primitive_attr
CreatePostOps
(
mkldnn
::
primitive_attr
CreatePostOps
(
const
std
::
vector
<
float
>
output_shift_scale
)
const
{
bool
fuse_relu
,
const
std
::
vector
<
float
>
output_shift_scale
)
const
{
mkldnn
::
primitive_attr
conv_attr
;
mkldnn
::
primitive_attr
conv_attr
;
mkldnn
::
post_ops
post_operations
;
mkldnn
::
post_ops
post_operations
;
int
mask
=
output_shift_scale
.
size
()
>
1
?
1
<<
1
:
0
;
int
mask
=
output_shift_scale
.
size
()
>
1
?
1
<<
1
:
0
;
conv_attr
.
set_output_scales
(
mask
,
output_shift_scale
);
conv_attr
.
set_output_scales
(
mask
,
output_shift_scale
);
if
(
fuse_relu
)
{
constexpr
float
scale
=
1.0
f
;
constexpr
float
negative_slope
=
0.0
f
;
constexpr
float
placeholder
=
1.0
f
;
// beta
post_operations
.
append_eltwise
(
scale
,
mkldnn
::
algorithm
::
eltwise_relu
,
negative_slope
,
placeholder
);
}
conv_attr
.
set_post_ops
(
post_operations
);
conv_attr
.
set_post_ops
(
post_operations
);
return
conv_attr
;
return
conv_attr
;
}
}
...
@@ -600,7 +621,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -600,7 +621,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
ConvFwdPrimitiveDesc
(
const
memory
::
desc
&
src
,
const
memory
::
desc
&
weights
,
ConvFwdPrimitiveDesc
(
const
memory
::
desc
&
src
,
const
memory
::
desc
&
weights
,
const
memory
::
desc
&
dst
,
const
std
::
vector
<
int
>&
strides
,
const
memory
::
desc
&
dst
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
paddings
,
const
mkldnn
::
engine
&
engine
,
const
mkldnn
::
engine
&
engine
,
const
bool
fuse_relu
,
const
std
::
vector
<
float
>
output_shift_scale
,
const
std
::
vector
<
float
>
output_shift_scale
,
bool
is_test
)
const
{
bool
is_test
)
const
{
memory
::
dims
stride_dims
=
{
strides
[
0
],
strides
[
1
]};
memory
::
dims
stride_dims
=
{
strides
[
0
],
strides
[
1
]};
...
@@ -613,7 +634,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -613,7 +634,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
propagation
,
mkldnn
::
convolution_direct
,
src
,
weights
,
dst
,
stride_dims
,
propagation
,
mkldnn
::
convolution_direct
,
src
,
weights
,
dst
,
stride_dims
,
padding_dims
,
padding_dims
,
mkldnn
::
padding_kind
::
zero
);
padding_dims
,
padding_dims
,
mkldnn
::
padding_kind
::
zero
);
mkldnn
::
primitive_attr
conv_attr
=
CreatePostOps
(
output_shift_scale
);
mkldnn
::
primitive_attr
conv_attr
=
CreatePostOps
(
fuse_relu
,
output_shift_scale
);
auto
p_conv_pd
=
new
mkldnn
::
convolution_forward
::
primitive_desc
(
auto
p_conv_pd
=
new
mkldnn
::
convolution_forward
::
primitive_desc
(
conv_desc
,
conv_attr
,
engine
);
conv_desc
,
conv_attr
,
engine
);
...
@@ -652,7 +674,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -652,7 +674,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const
memory
::
desc
&
bias
,
const
memory
::
desc
&
dst
,
const
memory
::
desc
&
bias
,
const
memory
::
desc
&
dst
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
paddings
,
const
mkldnn
::
engine
&
engine
,
const
mkldnn
::
engine
&
engine
,
const
bool
fuse_relu
,
const
std
::
vector
<
float
>
output_shift_scale
,
const
std
::
vector
<
float
>
output_shift_scale
,
bool
is_test
)
const
{
bool
is_test
)
const
{
memory
::
dims
stride_dims
=
{
strides
[
0
],
strides
[
1
]};
memory
::
dims
stride_dims
=
{
strides
[
0
],
strides
[
1
]};
...
@@ -665,7 +687,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -665,7 +687,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
propagation
,
mkldnn
::
convolution_direct
,
src
,
weights
,
bias
,
dst
,
propagation
,
mkldnn
::
convolution_direct
,
src
,
weights
,
bias
,
dst
,
stride_dims
,
padding_dims
,
padding_dims
,
mkldnn
::
padding_kind
::
zero
);
stride_dims
,
padding_dims
,
padding_dims
,
mkldnn
::
padding_kind
::
zero
);
mkldnn
::
primitive_attr
conv_attr
=
CreatePostOps
(
output_shift_scale
);
mkldnn
::
primitive_attr
conv_attr
=
CreatePostOps
(
fuse_relu
,
output_shift_scale
);
auto
p_conv_pd
=
new
mkldnn
::
convolution_forward
::
primitive_desc
(
auto
p_conv_pd
=
new
mkldnn
::
convolution_forward
::
primitive_desc
(
conv_desc
,
conv_attr
,
engine
);
conv_desc
,
conv_attr
,
engine
);
...
...
paddle/fluid/operators/linear_chain_crf_op.cc
浏览文件 @
68a07328
...
@@ -230,10 +230,12 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
...
@@ -230,10 +230,12 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Emission"
)))
{
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Emission"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Emission"
),
emission_exps_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Emission"
),
emission_exps_dims
);
ctx
->
ShareLoD
(
"Emission"
,
framework
::
GradVarName
(
"Emission"
));
}
}
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Transition"
)))
{
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Transition"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Transition"
),
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Transition"
),
transition_exps_dims
);
transition_exps_dims
);
ctx
->
ShareLoD
(
"Transition"
,
framework
::
GradVarName
(
"Transition"
));
}
}
}
}
...
...
paddle/fluid/operators/math/blas_impl.cu.h
浏览文件 @
68a07328
...
@@ -62,27 +62,19 @@ struct CUBlas<float> {
...
@@ -62,27 +62,19 @@ struct CUBlas<float> {
cudaDataType_t
Atype
,
int
lda
,
const
void
*
B
,
cudaDataType_t
Atype
,
int
lda
,
const
void
*
B
,
cudaDataType_t
Btype
,
int
ldb
,
const
float
*
beta
,
void
*
C
,
cudaDataType_t
Btype
,
int
ldb
,
const
float
*
beta
,
void
*
C
,
cudaDataType_t
Ctype
,
int
ldc
)
{
cudaDataType_t
Ctype
,
int
ldc
)
{
// Because the gcc 4.8 doesn't expand template parameter pack that
// Because the gcc 4.8 doesn't expand template parameter pack that
// appears in a lambda-expression, I can not use template parameter pack
// appears in a lambda-expression, I can not use template parameter pack
// here.
// here.
auto
cublas_call
=
[
&
]()
{
#if CUDA_VERSION >= 8000
#if CUDA_VERSION >= 8000
VLOG
(
5
)
<<
"use_tensor_op_math: "
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
platform
::
TensorCoreAvailable
()
?
"True"
:
"False"
);
<<
(
dev_ctx
->
tensor_core_available
()
?
"True"
:
"False"
);
dev_ctx
->
TensorCoreCublasCallIfAvailable
([
&
](
cublasHandle_t
handle
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemmEx
(
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemmEx
(
dev_ctx
->
cublas_handle
(),
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
lda
,
B
,
Btype
,
ldb
,
lda
,
B
,
Btype
,
ldb
,
beta
,
C
,
Ctype
,
ldc
));
beta
,
C
,
Ctype
,
ldc
));
});
#else
#else
PADDLE_THROW
(
"cublasSgemmEx is supported on cuda >= 8.0"
);
PADDLE_THROW
(
"cublasSgemmEx is supported on cuda >= 8.0"
);
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx
->
CublasCall
(
cublas_call
,
CUBLAS_TENSOR_OP_MATH
);
#else
cublas_call
();
#endif
#endif
}
}
};
};
...
@@ -170,32 +162,24 @@ struct CUBlas<platform::float16> {
...
@@ -170,32 +162,24 @@ struct CUBlas<platform::float16> {
cudaDataType_t
Btype
,
int
ldb
,
const
void
*
beta
,
void
*
C
,
cudaDataType_t
Btype
,
int
ldb
,
const
void
*
beta
,
void
*
C
,
cudaDataType_t
Ctype
,
int
ldc
,
cudaDataType_t
Ctype
,
int
ldc
,
cudaDataType_t
computeType
)
{
cudaDataType_t
computeType
)
{
auto
cublas_call
=
[
&
]()
{
#if CUDA_VERSION >= 8000
#if CUDA_VERSION >= 8000
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
#if CUDA_VERSION >= 9000
#if CUDA_VERSION >= 9000
bool
use_tensor_op_math
=
platform
::
TensorCoreA
vailable
();
bool
use_tensor_op_math
=
dev_ctx
->
tensor_core_a
vailable
();
if
(
use_tensor_op_math
)
{
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
#endif // CUDA_VERSION >= 9000
#endif // CUDA_VERSION >= 9000
dev_ctx
->
TensorCoreCublasCallIfAvailable
([
&
](
cublasHandle_t
handle
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmEx
(
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmEx
(
dev_ctx
->
cublas_handle
(),
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
lda
,
B
,
Btype
,
ldb
,
lda
,
B
,
Btype
,
ldb
,
beta
,
C
,
Ctype
,
ldc
,
computeType
,
algo
));
beta
,
C
,
Ctype
,
ldc
,
computeType
,
algo
));
});
#else
#else
PADDLE_THROW
(
"cublasGemmEx is supported on cuda >= 8.0"
);
PADDLE_THROW
(
"cublasGemmEx is supported on cuda >= 8.0"
);
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx
->
CublasCall
(
cublas_call
,
CUBLAS_TENSOR_OP_MATH
);
#else
cublas_call
();
#endif
#endif
}
}
};
};
...
@@ -223,9 +207,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
...
@@ -223,9 +207,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CUDA_R_32F
,
N
);
CUDA_R_32F
,
N
);
}
else
{
}
else
{
#endif // CUDA_VERSION >= 8000
#endif // CUDA_VERSION >= 8000
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
CUBlas
<
T
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
);
lda
,
&
beta
,
C
,
N
);
});
#if CUDA_VERSION >= 8000
#if CUDA_VERSION >= 8000
}
}
...
@@ -266,9 +251,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
...
@@ -266,9 +251,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
CUDA_R_16F
,
lda
,
&
h_beta
,
C
,
CUDA_R_16F
,
N
,
CUDA_R_32F
);
CUDA_R_16F
,
lda
,
&
h_beta
,
C
,
CUDA_R_16F
,
N
,
CUDA_R_32F
);
#else
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
CUBlas
<
platform
::
float16
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
&
h_beta
,
h_C
,
N
);
CUBlas
<
platform
::
float16
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
&
h_beta
,
h_C
,
N
);
});
#endif // CUDA_VERSION >= 8000
#endif // CUDA_VERSION >= 8000
}
}
...
@@ -292,8 +280,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
...
@@ -292,8 +280,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
}
else
{
}
else
{
#endif // CUDA_VERSION >= 8000
#endif // CUDA_VERSION >= 8000
CUBlas
<
T
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
CUBlas
<
T
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
});
#if CUDA_VERSION >= 8000
#if CUDA_VERSION >= 8000
}
}
...
@@ -311,16 +301,19 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
...
@@ -311,16 +301,19 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
cublasOperation_t
cuTransA
=
transA
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
cuTransA
=
transA
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
cuTransB
=
transB
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
cuTransB
=
transB
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
CUBlas
<
platform
::
float16
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
CUBlas
<
platform
::
float16
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
ldc
);
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
});
}
}
template
<
>
template
<
>
template
<
typename
T
>
template
<
typename
T
>
void
Blas
<
platform
::
CUDADeviceContext
>::
AXPY
(
int
n
,
T
alpha
,
const
T
*
x
,
void
Blas
<
platform
::
CUDADeviceContext
>::
AXPY
(
int
n
,
T
alpha
,
const
T
*
x
,
T
*
y
)
const
{
T
*
y
)
const
{
CUBlas
<
T
>::
AXPY
(
context_
.
cublas_handle
(),
n
,
&
alpha
,
x
,
1
,
y
,
1
);
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
AXPY
(
handle
,
n
,
&
alpha
,
x
,
1
,
y
,
1
);
});
}
}
template
<
>
template
<
>
...
@@ -330,8 +323,9 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
...
@@ -330,8 +323,9 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
T
beta
,
T
*
C
)
const
{
T
beta
,
T
*
C
)
const
{
cublasOperation_t
cuTransA
=
!
trans_a
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
cuTransA
=
!
trans_a
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
CUBlas
<
T
>::
GEMV
(
context_
.
cublas_handle
(),
cuTransA
,
N
,
M
,
&
alpha
,
A
,
N
,
B
,
1
,
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
&
beta
,
C
,
1
);
CUBlas
<
T
>::
GEMV
(
handle
,
cuTransA
,
N
,
M
,
&
alpha
,
A
,
N
,
B
,
1
,
&
beta
,
C
,
1
);
});
}
}
template
<
>
template
<
>
...
@@ -353,28 +347,28 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
...
@@ -353,28 +347,28 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
#if CUDA_VERSION >= 9010
#if CUDA_VERSION >= 9010
if
(
FLAGS_enable_cublas_tensor_op_math
&&
std
::
is_same
<
T
,
float
>::
value
)
{
if
(
FLAGS_enable_cublas_tensor_op_math
&&
std
::
is_same
<
T
,
float
>::
value
)
{
auto
cublas_call
=
[
&
]()
{
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
bool
use_tensor_op_math
=
context_
.
tensor_core_available
()
;
bool
use_tensor_op_math
=
platform
::
TensorCoreAvailable
();
if
(
use_tensor_op_math
)
{
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
context_
.
TensorCoreCublasCallIfAvailable
([
&
](
cublasHandle_t
handle
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmStridedBatchedEx
(
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmStridedBatchedEx
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
CUDA_R_32F
,
ldb
,
CUDA_R_32F
,
ldb
,
strideB
,
A
,
CUDA_R_32F
,
lda
,
strideA
,
&
beta
,
C
,
strideB
,
A
,
CUDA_R_32F
,
lda
,
strideA
,
&
beta
,
C
,
CUDA_R_32F
,
ldc
,
CUDA_R_32F
,
ldc
,
strideC
,
batchCount
,
CUDA_R_32F
,
algo
));
strideC
,
batchCount
,
CUDA_R_32F
,
algo
));
};
});
auto
&
dev_ctx
=
const_cast
<
platform
::
CUDADeviceContext
&>
(
context_
);
dev_ctx
.
CublasCall
(
cublas_call
,
CUBLAS_TENSOR_OP_MATH
);
}
else
{
}
else
{
#endif // CUDA_VERSION >= 9010
#endif // CUDA_VERSION >= 9010
CUBlas
<
T
>::
GEMM_STRIDED_BATCH
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
strideB
,
A
,
lda
,
CUBlas
<
T
>::
GEMM_STRIDED_BATCH
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
strideA
,
&
beta
,
C
,
ldc
,
strideC
,
batchCount
);
B
,
ldb
,
strideB
,
A
,
lda
,
strideA
,
&
beta
,
C
,
ldc
,
strideC
,
batchCount
);
});
#if CUDA_VERSION >= 9010
#if CUDA_VERSION >= 9010
}
}
...
...
paddle/fluid/operators/optimizers/adam_op.h
浏览文件 @
68a07328
...
@@ -424,16 +424,23 @@ class AdamOpKernel : public framework::OpKernel<T> {
...
@@ -424,16 +424,23 @@ class AdamOpKernel : public framework::OpKernel<T> {
}
}
}
}
framework
::
SelectedRows
cpu_grad_merge
;
const
framework
::
SelectedRows
*
grad_merge_ptr
;
const
framework
::
SelectedRows
*
grad_merge_ptr
;
if
(
is_strict_sorted
)
{
if
(
is_strict_sorted
)
{
grad_merge_ptr
=
&
grad
;
grad_merge_ptr
=
&
grad
;
}
else
{
}
else
{
// merge duplicated rows if any.
// merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor
// The rows of grad_merge have been sorted inside MergeAdd functor
framework
::
SelectedRows
*
grad_merge_var
;
scatter
::
MergeAdd
<
DeviceContext
,
T
>
merge_func
;
scatter
::
MergeAdd
<
DeviceContext
,
T
>
merge_func
;
auto
*
grad_merge_var
=
const_cast
<
framework
::
Scope
&>
(
ctx
.
scope
())
if
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()))
{
.
Var
()
grad_merge_var
=
&
cpu_grad_merge
;
->
GetMutable
<
framework
::
SelectedRows
>
();
}
else
{
// FIXME(qiao): GPU also need to fix this
grad_merge_var
=
const_cast
<
framework
::
Scope
&>
(
ctx
.
scope
())
.
Var
()
->
GetMutable
<
framework
::
SelectedRows
>
();
}
merge_func
(
ctx
.
template
device_context
<
DeviceContext
>(),
grad
,
merge_func
(
ctx
.
template
device_context
<
DeviceContext
>(),
grad
,
grad_merge_var
,
true
);
grad_merge_var
,
true
);
grad_merge_ptr
=
grad_merge_var
;
grad_merge_ptr
=
grad_merge_var
;
...
...
paddle/fluid/
framework/details/multi_devices_graph_check_pass
.h
→
paddle/fluid/
platform/cuda_helper
.h
浏览文件 @
68a07328
// Copyright (c) 201
8
PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 201
9
PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// you may not use this file except in compliance with the License.
...
@@ -14,25 +14,45 @@
...
@@ -14,25 +14,45 @@
#pragma once
#pragma once
#include
"paddle/fluid/framework/details/multi_devices_helper.h"
#include
<mutex> // NOLINT
#include <string>
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/macros.h"
#if CUDA_VERSION < 9000
enum
cublasMath_t
{
CUBLAS_DEFAULT_MATH
=
0
};
#endif
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
platform
{
namespace
details
{
class
CublasHandleHolder
{
class
SSAGraghBuilderWithChecker
:
public
ir
::
Pass
{
public:
protected:
CublasHandleHolder
(
cudaStream_t
stream
,
cublasMath_t
math_type
)
{
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
handle_
));
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
handle_
,
stream
));
PADDLE_ENFORCE
(
IsValidGraph
(
graph
.
get
()));
#if CUDA_VERSION >= 9000
return
graph
;
if
(
math_type
==
CUBLAS_TENSOR_OP_MATH
)
{
PADDLE_ENFORCE
(
dynload
::
cublasSetMathMode
(
handle_
,
CUBLAS_TENSOR_OP_MATH
));
}
#endif
}
~
CublasHandleHolder
()
{
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
handle_
));
}
template
<
typename
Callback
>
inline
void
Call
(
Callback
&&
callback
)
const
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mtx_
);
callback
(
handle_
);
}
}
bool
IsValidGraph
(
const
ir
::
Graph
*
graph
)
const
;
private:
DISABLE_COPY_AND_ASSIGN
(
CublasHandleHolder
);
cublasHandle_t
handle_
;
mutable
std
::
mutex
mtx_
;
};
};
}
// namespace details
}
// namespace platform
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/platform/device_context.cc
浏览文件 @
68a07328
...
@@ -245,8 +245,15 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
...
@@ -245,8 +245,15 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
eigen_stream_
.
reset
(
new
EigenCudaStreamDevice
());
eigen_stream_
.
reset
(
new
EigenCudaStreamDevice
());
eigen_stream_
->
Reinitialize
(
&
stream_
,
place
);
eigen_stream_
->
Reinitialize
(
&
stream_
,
place
);
eigen_device_
.
reset
(
new
Eigen
::
GpuDevice
(
eigen_stream_
.
get
()));
eigen_device_
.
reset
(
new
Eigen
::
GpuDevice
(
eigen_stream_
.
get
()));
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
cublas_handle_
));
cublas_handle_
.
reset
(
new
CublasHandleHolder
(
stream_
,
CUBLAS_DEFAULT_MATH
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
cublas_handle_
,
stream_
));
if
(
TensorCoreAvailable
())
{
#if CUDA_VERSION >= 9000
cublas_tensor_core_handle_
.
reset
(
new
CublasHandleHolder
(
stream_
,
CUBLAS_TENSOR_OP_MATH
));
#endif
}
if
(
dynload
::
HasCUDNN
())
{
if
(
dynload
::
HasCUDNN
())
{
cudnn_holder_
.
reset
(
new
CudnnHolder
(
&
stream_
,
place
));
cudnn_holder_
.
reset
(
new
CudnnHolder
(
&
stream_
,
place
));
}
}
...
@@ -306,7 +313,8 @@ CUDADeviceContext::~CUDADeviceContext() {
...
@@ -306,7 +313,8 @@ CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId
(
place_
.
device
);
SetDeviceId
(
place_
.
device
);
Wait
();
Wait
();
WaitStreamCallback
();
WaitStreamCallback
();
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
cublas_handle_
));
cublas_handle_
.
reset
();
cublas_tensor_core_handle_
.
reset
();
eigen_stream_
.
reset
();
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
eigen_device_
.
reset
();
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
...
@@ -335,8 +343,8 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
...
@@ -335,8 +343,8 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return
eigen_device_
.
get
();
return
eigen_device_
.
get
();
}
}
cublasHandle_t
CUDADeviceContext
::
cublas_hand
le
()
const
{
bool
CUDADeviceContext
::
tensor_core_availab
le
()
const
{
return
cublas_
handle_
;
return
cublas_
tensor_core_handle_
!=
nullptr
;
}
}
cudnnHandle_t
CUDADeviceContext
::
cudnn_handle
()
const
{
cudnnHandle_t
CUDADeviceContext
::
cudnn_handle
()
const
{
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
68a07328
...
@@ -20,6 +20,7 @@ limitations under the License. */
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/temporary_allocator.h"
#include "paddle/fluid/platform/temporary_allocator.h"
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_helper.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/gpu_info.h"
...
@@ -209,39 +210,6 @@ class CudnnWorkspaceHandle {
...
@@ -209,39 +210,6 @@ class CudnnWorkspaceHandle {
std
::
unique_ptr
<
std
::
lock_guard
<
std
::
mutex
>>
guard_
;
std
::
unique_ptr
<
std
::
lock_guard
<
std
::
mutex
>>
guard_
;
};
};
#if CUDA_VERSION >= 9000
class
ScopedCublasMathMode
{
public:
ScopedCublasMathMode
(
cublasHandle_t
handle
,
cublasMath_t
new_math_mode
)
:
handle_
(
handle
)
{
need_reset
=
false
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGetMathMode
(
handle_
,
&
old_math_mode_
),
"Failed to get old cublas math mode"
);
if
(
old_math_mode_
!=
new_math_mode
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
handle_
,
new_math_mode
),
"Failed to set old cublas math mode"
);
need_reset
=
true
;
}
}
~
ScopedCublasMathMode
()
{
if
(
need_reset
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
handle_
,
old_math_mode_
),
"Failed to set old cublas math mode"
);
}
}
private:
cublasHandle_t
handle_
;
cublasMath_t
old_math_mode_
;
bool
need_reset
;
};
#endif
class
CUDADeviceContext
:
public
DeviceContext
{
class
CUDADeviceContext
:
public
DeviceContext
{
public:
public:
explicit
CUDADeviceContext
(
CUDAPlace
place
);
explicit
CUDADeviceContext
(
CUDAPlace
place
);
...
@@ -262,8 +230,25 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -262,8 +230,25 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return eigen device in the device context. */
/*! \brief Return eigen device in the device context. */
Eigen
::
GpuDevice
*
eigen_device
()
const
;
Eigen
::
GpuDevice
*
eigen_device
()
const
;
/*! \brief Return cublas handle in the device context. */
/*! \brief Call cublas function safely. */
cublasHandle_t
cublas_handle
()
const
;
template
<
typename
Callback
>
inline
void
CublasCall
(
Callback
&&
callback
)
const
{
cublas_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
/*! \brief Check whether tensor core is supported */
bool
tensor_core_available
()
const
;
/*! \brief Call cublas function with Tensor Core safely. If
Tensor Core is not available, use DEFAULT_MATH instead. */
template
<
typename
Callback
>
inline
void
TensorCoreCublasCallIfAvailable
(
Callback
&&
callback
)
const
{
if
(
cublas_tensor_core_handle_
)
{
cublas_tensor_core_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
else
{
cublas_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
}
/*! \brief Return cudnn handle in the device context. */
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t
cudnn_handle
()
const
;
cudnnHandle_t
cudnn_handle
()
const
;
...
@@ -282,7 +267,6 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -282,7 +267,6 @@ class CUDADeviceContext : public DeviceContext {
template
<
typename
Callback
>
template
<
typename
Callback
>
void
RecordEvent
(
cudaEvent_t
ev
,
Callback
callback
)
{
void
RecordEvent
(
cudaEvent_t
ev
,
Callback
callback
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mtx_
);
callback
();
callback
();
PADDLE_ENFORCE
(
cudaEventRecord
(
ev
,
stream_
));
PADDLE_ENFORCE
(
cudaEventRecord
(
ev
,
stream_
));
}
}
...
@@ -294,18 +278,6 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -294,18 +278,6 @@ class CUDADeviceContext : public DeviceContext {
void
WaitStreamCallback
()
const
{
callback_manager_
->
Wait
();
}
void
WaitStreamCallback
()
const
{
callback_manager_
->
Wait
();
}
#if CUDA_VERSION >= 9000
/*! \brief CublasCall may need to change cublas's config,
* but the cublas may be hold by multi-thread, so we should
* add lock here. */
template
<
typename
Callback
>
void
CublasCall
(
Callback
callback
,
cublasMath_t
new_math
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
cublas_mtx_
);
ScopedCublasMathMode
scoped_cublas_math
(
cublas_handle_
,
new_math
);
callback
();
}
#endif
private:
private:
CUDAPlace
place_
;
CUDAPlace
place_
;
...
@@ -313,7 +285,9 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -313,7 +285,9 @@ class CUDADeviceContext : public DeviceContext {
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
std
::
unique_ptr
<
CudnnHolder
>
cudnn_holder_
;
std
::
unique_ptr
<
CudnnHolder
>
cudnn_holder_
;
cudaStream_t
stream_
;
cudaStream_t
stream_
;
cublasHandle_t
cublas_handle_
;
std
::
unique_ptr
<
CublasHandleHolder
>
cublas_handle_
;
std
::
unique_ptr
<
CublasHandleHolder
>
cublas_tensor_core_handle_
;
int
compute_capability_
;
int
compute_capability_
;
int
runtime_version_
;
int
runtime_version_
;
...
@@ -321,12 +295,10 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -321,12 +295,10 @@ class CUDADeviceContext : public DeviceContext {
int
multi_process_
;
int
multi_process_
;
int
max_threads_per_mp_
;
int
max_threads_per_mp_
;
mutable
std
::
mutex
mtx_
;
// StreamCallbackManager is thread-safe
// StreamCallbackManager is thread-safe
std
::
unique_ptr
<
StreamCallbackManager
>
callback_manager_
;
std
::
unique_ptr
<
StreamCallbackManager
>
callback_manager_
;
mutable
std
::
mutex
cublas_mtx_
;
DISABLE_COPY_AND_ASSIGN
(
CUDADeviceContext
)
;
};
};
template
<
>
template
<
>
...
...
paddle/fluid/platform/device_context_test.cu
浏览文件 @
68a07328
...
@@ -43,9 +43,6 @@ TEST(Device, CUDADeviceContext) {
...
@@ -43,9 +43,6 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE
(
nullptr
,
gpu_device
);
ASSERT_NE
(
nullptr
,
gpu_device
);
cudnnHandle_t
cudnn_handle
=
device_context
->
cudnn_handle
();
cudnnHandle_t
cudnn_handle
=
device_context
->
cudnn_handle
();
ASSERT_NE
(
nullptr
,
cudnn_handle
);
ASSERT_NE
(
nullptr
,
cudnn_handle
);
cublasHandle_t
cublas_handle
=
device_context
->
cublas_handle
();
ASSERT_NE
(
nullptr
,
cublas_handle
);
ASSERT_NE
(
nullptr
,
device_context
->
stream
());
delete
device_context
;
delete
device_context
;
}
}
}
}
...
...
paddle/fluid/platform/mkldnn_reuse.h
浏览文件 @
68a07328
...
@@ -214,16 +214,18 @@ class MKLDNNHandler {
...
@@ -214,16 +214,18 @@ class MKLDNNHandler {
std
::
string
*
key
,
const
mkldnn
::
memory
::
dims
&
input_dims
,
std
::
string
*
key
,
const
mkldnn
::
memory
::
dims
&
input_dims
,
const
mkldnn
::
memory
::
dims
&
weights_dims
,
const
std
::
vector
<
int
>&
strides
,
const
mkldnn
::
memory
::
dims
&
weights_dims
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
int
&
groups
,
const
mkldnn
::
memory
::
data_type
&
type
,
const
int
&
groups
,
const
mkldnn
::
memory
::
data_type
&
srcdt
,
const
mkldnn
::
memory
::
format
&
format
,
const
std
::
string
&
suffix
)
{
const
mkldnn
::
memory
::
format
&
format
,
const
mkldnn
::
memory
::
data_type
&
dstdt
,
const
std
::
string
&
suffix
)
{
AppendKeyDims
(
key
,
input_dims
);
AppendKeyDims
(
key
,
input_dims
);
AppendKeyDims
(
key
,
weights_dims
);
AppendKeyDims
(
key
,
weights_dims
);
AppendKeyVec
(
key
,
strides
);
AppendKeyVec
(
key
,
strides
);
AppendKeyVec
(
key
,
paddings
);
AppendKeyVec
(
key
,
paddings
);
AppendKeyVec
(
key
,
dilations
);
AppendKeyVec
(
key
,
dilations
);
AppendKey
(
key
,
std
::
to_string
(
groups
));
AppendKey
(
key
,
std
::
to_string
(
groups
));
AppendKey
(
key
,
std
::
to_string
(
type
));
AppendKey
(
key
,
std
::
to_string
(
srcdt
));
AppendKey
(
key
,
std
::
to_string
(
format
));
AppendKey
(
key
,
std
::
to_string
(
format
));
AppendKey
(
key
,
std
::
to_string
(
dstdt
));
AppendKey
(
key
,
suffix
);
AppendKey
(
key
,
suffix
);
}
}
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
68a07328
...
@@ -946,13 +946,6 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -946,13 +946,6 @@ All parameter, weight, gradient are variables in Paddle.
R"DOC(The type is STR, debug_graphviz_path indicate the path that
R"DOC(The type is STR, debug_graphviz_path indicate the path that
writing the SSA Graph to file in the form of graphviz, you.
writing the SSA Graph to file in the form of graphviz, you.
It is useful for debugging. Default "")DOC"
)
It is useful for debugging. Default "")DOC"
)
.
def_property
(
"enable_data_balance"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
enable_data_balance_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
PADDLE_ENFORCE
(
!
self
.
IsFinalized
(),
"BuildStrategy is finlaized."
);
self
.
enable_data_balance_
=
b
;
})
// FIXME(chengudo): enable_data_balance seems not important
.
def_property
(
.
def_property
(
"enable_sequential_execution"
,
"enable_sequential_execution"
,
[](
const
BuildStrategy
&
self
)
{
[](
const
BuildStrategy
&
self
)
{
...
@@ -1007,6 +1000,10 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -1007,6 +1000,10 @@ All parameter, weight, gradient are variables in Paddle.
"memory_optimize"
,
"memory_optimize"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
memory_optimize_
;
},
[](
const
BuildStrategy
&
self
)
{
return
self
.
memory_optimize_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
memory_optimize_
=
b
;
})
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
memory_optimize_
=
b
;
})
.
def_property
(
"is_distribution"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
is_distribution_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
is_distribution_
=
b
;
})
.
def_property
(
.
def_property
(
"memory_early_delete"
,
"memory_early_delete"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
memory_early_delete_
;
},
[](
const
BuildStrategy
&
self
)
{
return
self
.
memory_early_delete_
;
},
...
...
python/paddle/fluid/parallel_executor.py
浏览文件 @
68a07328
...
@@ -29,6 +29,15 @@ ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy
...
@@ -29,6 +29,15 @@ ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy
BuildStrategy
=
core
.
ParallelExecutor
.
BuildStrategy
BuildStrategy
=
core
.
ParallelExecutor
.
BuildStrategy
def
_is_pserver_mode
(
main_program
):
main
=
main_program
if
main_program
\
else
framework
.
default_main_program
()
for
op
in
main
.
global_block
().
ops
:
if
op
.
type
in
[
"send"
,
"recv"
]:
return
True
return
False
class
ParallelExecutor
(
object
):
class
ParallelExecutor
(
object
):
"""
"""
ParallelExecutor is designed for data parallelism, which focuses on distributing
ParallelExecutor is designed for data parallelism, which focuses on distributing
...
@@ -128,6 +137,11 @@ class ParallelExecutor(object):
...
@@ -128,6 +137,11 @@ class ParallelExecutor(object):
build_strategy
=
BuildStrategy
()
build_strategy
=
BuildStrategy
()
build_strategy
.
num_trainers
=
num_trainers
build_strategy
.
num_trainers
=
num_trainers
build_strategy
.
trainer_id
=
trainer_id
build_strategy
.
trainer_id
=
trainer_id
# FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
# num_trainers is 1, so the current fields of build_strategy doesn't tell if
# it's distributed model.
build_strategy
.
is_distribution
=
_is_pserver_mode
(
main_program
)
or
num_trainers
>
1
# step4: get main_program, scope, local_scopes
# step4: get main_program, scope, local_scopes
main
=
main_program
if
main_program
\
main
=
main_program
if
main_program
\
...
...
python/paddle/fluid/tests/unittests/test_conv2d_int8_mkldnn_op.py
浏览文件 @
68a07328
...
@@ -47,7 +47,8 @@ class TestConv2dInt8Op(TestConv2dOp):
...
@@ -47,7 +47,8 @@ class TestConv2dInt8Op(TestConv2dOp):
self
.
init_group
()
self
.
init_group
()
self
.
init_dilation
()
self
.
init_dilation
()
self
.
init_test_case
()
self
.
init_test_case
()
self
.
init_dtype
()
self
.
init_fuse_relu
()
self
.
init_data_type
()
conv2d_param
=
{
conv2d_param
=
{
'stride'
:
self
.
stride
,
'stride'
:
self
.
stride
,
...
@@ -78,7 +79,11 @@ class TestConv2dInt8Op(TestConv2dOp):
...
@@ -78,7 +79,11 @@ class TestConv2dInt8Op(TestConv2dOp):
np
.
round
((
input_shift
)
*
self
.
scale_in
).
astype
(
np
.
int32
),
np
.
round
((
input_shift
)
*
self
.
scale_in
).
astype
(
np
.
int32
),
filter_int
,
self
.
groups
,
filter_int
,
self
.
groups
,
conv2d_param
).
astype
(
np
.
float32
)
*
scale_output_shift
conv2d_param
).
astype
(
np
.
float32
)
*
scale_output_shift
output
=
np
.
round
(
output1
-
output2
).
astype
(
self
.
dsttype
)
if
self
.
fuse_relu
:
output
=
np
.
maximum
(
np
.
round
(
output1
-
output2
),
0
).
astype
(
self
.
dsttype
)
else
:
output
=
np
.
round
(
output1
-
output2
).
astype
(
self
.
dsttype
)
else
:
else
:
filter_int
=
np
.
round
(
filter
*
filter_int
=
np
.
round
(
filter
*
self
.
scale_weights
[
0
]).
astype
(
np
.
int32
)
self
.
scale_weights
[
0
]).
astype
(
np
.
int32
)
...
@@ -87,7 +92,15 @@ class TestConv2dInt8Op(TestConv2dOp):
...
@@ -87,7 +92,15 @@ class TestConv2dInt8Op(TestConv2dOp):
output1
=
conv2d_forward_refer
(
output1
=
conv2d_forward_refer
(
input
.
astype
(
np
.
int32
),
filter_int
,
self
.
groups
,
input
.
astype
(
np
.
int32
),
filter_int
,
self
.
groups
,
conv2d_param
).
astype
(
np
.
float32
)
conv2d_param
).
astype
(
np
.
float32
)
output
=
np
.
round
(
output1
*
scale_output_shift
).
astype
(
self
.
dsttype
)
if
self
.
fuse_relu
:
output
=
np
.
maximum
(
np
.
round
(
output1
*
(
self
.
scale_out
/
(
self
.
scale_in
*
self
.
scale_weights
[
0
]))),
0
).
astype
(
self
.
dsttype
)
else
:
output
=
np
.
round
(
output1
*
(
self
.
scale_out
/
(
self
.
scale_in
*
self
.
scale_weights
[
0
]))).
astype
(
self
.
dsttype
)
self
.
inputs
=
{
self
.
inputs
=
{
'Input'
:
'Input'
:
...
@@ -106,6 +119,7 @@ class TestConv2dInt8Op(TestConv2dOp):
...
@@ -106,6 +119,7 @@ class TestConv2dInt8Op(TestConv2dOp):
'Scale_in'
:
self
.
scale_in
,
'Scale_in'
:
self
.
scale_in
,
'Scale_out'
:
self
.
scale_out
,
'Scale_out'
:
self
.
scale_out
,
'Scale_weights'
:
self
.
scale_weights
,
'Scale_weights'
:
self
.
scale_weights
,
'fuse_relu'
:
self
.
fuse_relu
}
}
self
.
outputs
=
{
'Output'
:
output
}
self
.
outputs
=
{
'Output'
:
output
}
...
@@ -129,12 +143,15 @@ class TestConv2dInt8Op(TestConv2dOp):
...
@@ -129,12 +143,15 @@ class TestConv2dInt8Op(TestConv2dOp):
self
.
scale_out
=
0.5
self
.
scale_out
=
0.5
self
.
scale_weights
=
[
10.0
]
self
.
scale_weights
=
[
10.0
]
def
init_dtype
(
self
):
def
init_d
ata_
type
(
self
):
self
.
srctype
=
np
.
uint8
self
.
srctype
=
np
.
uint8
self
.
dsttype
=
np
.
int8
self
.
dsttype
=
np
.
int8
def
init_fuse_relu
(
self
):
self
.
fuse_relu
=
True
#--------------------test conv2d u8 in and s8 out--------------------
#--------------------test conv2d u8 in and u8 out--------------------
class
TestConv2d
(
TestConv2dInt8Op
):
class
TestConv2d
(
TestConv2dInt8Op
):
...
@@ -203,18 +220,43 @@ class TestWithInput1x1Filter1x1(TestConv2dInt8Op):
...
@@ -203,18 +220,43 @@ class TestWithInput1x1Filter1x1(TestConv2dInt8Op):
self
.
groups
=
3
self
.
groups
=
3
#--------------------test conv2d s8 in and s8 out--------------------
def
init_data_type_with_fusion
(
self
,
input_dt
,
fuse_relu
):
self
.
srctype
=
input_dt
self
.
dsttype
=
np
.
uint8
if
fuse_relu
else
np
.
int8
def
init_fuse_relu
(
self
):
self
.
fuse_relu
=
fuse_relu
def
create_test_int8_class
(
parent
):
def
create_test_int8_class
(
parent
):
class
TestInt8Case
(
parent
):
def
init_dtype
(
self
):
#--------------------test conv2d s8 in and u8 out--------------------
self
.
srctype
=
np
.
int8
self
.
dsttype
=
np
.
int8
class
TestS8U8Case
(
parent
):
def
init_data_type
(
self
):
cls_name
=
"{0}_{1}"
.
format
(
parent
.
__name__
,
"s8s8"
)
init_data_type_with_fusion
(
self
,
np
.
int8
,
True
)
TestInt8Case
.
__name__
=
cls_name
globals
()[
cls_name
]
=
TestInt8Case
#--------------------test conv2d s8 in and s8 out--------------------
class
TestS8S8Case
(
parent
):
def
init_data_type
(
self
):
init_data_type_with_fusion
(
self
,
np
.
int8
,
False
)
#--------------------test conv2d u8 in and s8 out--------------------
class
TestU8S8Case
(
parent
):
def
init_data_type
(
self
):
init_data_type_with_fusion
(
self
,
np
.
uint8
,
False
)
cls_name_s8u8
=
"{0}_relu_{1}"
.
format
(
parent
.
__name__
,
"1"
)
cls_name_s8s8
=
"{0}_relu_{1}"
.
format
(
parent
.
__name__
,
"0"
)
cls_name_u8s8
=
"{0}_relu_{1}"
.
format
(
parent
.
__name__
,
"0"
)
TestS8U8Case
.
__name__
=
cls_name_s8u8
TestS8S8Case
.
__name__
=
cls_name_s8s8
TestU8S8Case
.
__name__
=
cls_name_u8s8
globals
()[
cls_name_s8u8
]
=
TestS8U8Case
globals
()[
cls_name_s8s8
]
=
TestS8S8Case
globals
()[
cls_name_u8s8
]
=
TestU8S8Case
create_test_int8_class
(
TestConv2dInt8Op
)
create_test_int8_class
(
TestConv2dInt8Op
)
...
...
python/paddle/fluid/tests/unittests/test_reader_reset.py
浏览文件 @
68a07328
...
@@ -75,8 +75,6 @@ class TestReaderReset(unittest.TestCase):
...
@@ -75,8 +75,6 @@ class TestReaderReset(unittest.TestCase):
exe
.
run
(
startup_prog
)
exe
.
run
(
startup_prog
)
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
=
fluid
.
BuildStrategy
()
if
with_double_buffer
:
build_strategy
.
enable_data_balance
=
True
exec_strategy
=
fluid
.
ExecutionStrategy
()
exec_strategy
=
fluid
.
ExecutionStrategy
()
parallel_exe
=
fluid
.
ParallelExecutor
(
parallel_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
self
.
use_cuda
,
use_cuda
=
self
.
use_cuda
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录