Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d0e3b240
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d0e3b240
编写于
1月 09, 2019
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into fix-dist-sparse-decay
test=develop
上级
c3b9edf9
223cc89f
变更
47
隐藏空白更改
内联
并排
Showing
47 changed file
with
1239 addition
and
197 deletion
+1239
-197
cmake/FindJeMalloc.cmake
cmake/FindJeMalloc.cmake
+7
-0
cmake/cuda.cmake
cmake/cuda.cmake
+13
-1
cmake/external/boost.cmake
cmake/external/boost.cmake
+2
-5
cmake/external/mkldnn.cmake
cmake/external/mkldnn.cmake
+1
-1
cmake/external/mklml.cmake
cmake/external/mklml.cmake
+16
-18
cmake/generic.cmake
cmake/generic.cmake
+1
-1
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+1
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+1
-0
paddle/fluid/framework/ir/lock_free_optimize_pass.cc
paddle/fluid/framework/ir/lock_free_optimize_pass.cc
+358
-0
paddle/fluid/framework/ir/lock_free_optimize_pass.h
paddle/fluid/framework/ir/lock_free_optimize_pass.h
+130
-0
paddle/fluid/framework/scope.cc
paddle/fluid/framework/scope.cc
+3
-2
paddle/fluid/framework/var_type_traits.cc
paddle/fluid/framework/var_type_traits.cc
+5
-3
paddle/fluid/framework/var_type_traits.h
paddle/fluid/framework/var_type_traits.h
+2
-2
paddle/fluid/framework/var_type_traits_test.cc
paddle/fluid/framework/var_type_traits_test.cc
+5
-4
paddle/fluid/inference/analysis/analyzer_tester.cc
paddle/fluid/inference/analysis/analyzer_tester.cc
+2
-2
paddle/fluid/inference/analysis/passes/CMakeLists.txt
paddle/fluid/inference/analysis/passes/CMakeLists.txt
+1
-0
paddle/fluid/operators/jit/benchmark.cc
paddle/fluid/operators/jit/benchmark.cc
+23
-0
paddle/fluid/operators/jit/gen/CMakeLists.txt
paddle/fluid/operators/jit/gen/CMakeLists.txt
+1
-0
paddle/fluid/operators/jit/gen/seqpool.cc
paddle/fluid/operators/jit/gen/seqpool.cc
+85
-0
paddle/fluid/operators/jit/gen/seqpool.h
paddle/fluid/operators/jit/gen/seqpool.h
+214
-0
paddle/fluid/operators/jit/helper.cc
paddle/fluid/operators/jit/helper.cc
+15
-0
paddle/fluid/operators/jit/helper.h
paddle/fluid/operators/jit/helper.h
+6
-0
paddle/fluid/operators/jit/kernel_base.h
paddle/fluid/operators/jit/kernel_base.h
+23
-0
paddle/fluid/operators/jit/kernel_key.cc
paddle/fluid/operators/jit/kernel_key.cc
+7
-0
paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
+1
-0
paddle/fluid/operators/jit/more/mkl/mkl.cc
paddle/fluid/operators/jit/more/mkl/mkl.cc
+31
-0
paddle/fluid/operators/jit/more/mkl/mkl.h
paddle/fluid/operators/jit/more/mkl/mkl.h
+26
-0
paddle/fluid/operators/jit/refer/CMakeLists.txt
paddle/fluid/operators/jit/refer/CMakeLists.txt
+1
-0
paddle/fluid/operators/jit/refer/refer.cc
paddle/fluid/operators/jit/refer/refer.cc
+2
-0
paddle/fluid/operators/jit/refer/refer.h
paddle/fluid/operators/jit/refer/refer.h
+24
-0
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+49
-0
paddle/fluid/operators/math/CMakeLists.txt
paddle/fluid/operators/math/CMakeLists.txt
+1
-1
paddle/fluid/operators/math/blas_impl.cu.h
paddle/fluid/operators/math/blas_impl.cu.h
+64
-70
paddle/fluid/operators/math/sequence_pooling.cc
paddle/fluid/operators/math/sequence_pooling.cc
+21
-11
paddle/fluid/operators/ngraph/ops/binary_unnary_op.h
paddle/fluid/operators/ngraph/ops/binary_unnary_op.h
+0
-2
paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h
paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h
+0
-2
paddle/fluid/operators/ngraph/ops/fill_constant_op.h
paddle/fluid/operators/ngraph/ops/fill_constant_op.h
+0
-2
paddle/fluid/operators/ngraph/ops/mean_op.h
paddle/fluid/operators/ngraph/ops/mean_op.h
+0
-2
paddle/fluid/operators/ngraph/ops/mul_op.h
paddle/fluid/operators/ngraph/ops/mul_op.h
+0
-2
paddle/fluid/operators/ngraph/ops/scale_op.h
paddle/fluid/operators/ngraph/ops/scale_op.h
+0
-2
paddle/fluid/operators/ngraph/ops/top_k_op.h
paddle/fluid/operators/ngraph/ops/top_k_op.h
+0
-2
paddle/fluid/platform/cuda_helper.h
paddle/fluid/platform/cuda_helper.h
+58
-0
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+13
-5
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+24
-52
paddle/fluid/platform/device_context_test.cu
paddle/fluid/platform/device_context_test.cu
+0
-3
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+1
-1
未找到文件。
cmake/FindJeMalloc.cmake
浏览文件 @
d0e3b240
...
...
@@ -19,3 +19,10 @@ find_package_handle_standard_args(jemalloc DEFAULT_MSG JEMALLOC_LIBRARIES JEMALL
mark_as_advanced
(
JEMALLOC_LIBRARIES
JEMALLOC_INCLUDE_DIR
)
if
(
JEMALLOC_FOUND
)
add_library
(
jemalloc::jemalloc UNKNOWN IMPORTED
)
set_target_properties
(
jemalloc::jemalloc PROPERTIES
IMPORTED_LOCATION
${
JEMALLOC_LIBRARIES
}
INTERFACE_INCLUDE_DIRECTORIES
"
${
JEMALLOC_INCLUDE_DIR
}
"
)
endif
()
cmake/cuda.cmake
浏览文件 @
d0e3b240
...
...
@@ -2,9 +2,11 @@ if(NOT WITH_GPU)
return
()
endif
()
set
(
paddle_known_gpu_archs
"30 35 50 52 60 61 70
75
"
)
set
(
paddle_known_gpu_archs
"30 35 50 52 60 61 70"
)
set
(
paddle_known_gpu_archs7
"30 35 50 52"
)
set
(
paddle_known_gpu_archs8
"30 35 50 52 60 61"
)
set
(
paddle_known_gpu_archs9
"30 35 50 52 60 61 70"
)
set
(
paddle_known_gpu_archs10
"30 35 50 52 60 61 70 75"
)
######################################################################################
# A function for automatic detection of GPUs installed (if autodetection is enabled)
...
...
@@ -155,6 +157,16 @@ elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x
# warning for now.
list
(
APPEND CUDA_NVCC_FLAGS
"-Wno-deprecated-gpu-targets"
)
add_definitions
(
"-DPADDLE_CUDA_BINVER=
\"
80
\"
"
)
elseif
(
${
CUDA_VERSION
}
LESS 10.0
)
# CUDA 9.x
set
(
paddle_known_gpu_archs
${
paddle_known_gpu_archs9
}
)
list
(
APPEND CUDA_NVCC_FLAGS
"-D_MWAITXINTRIN_H_INCLUDED"
)
list
(
APPEND CUDA_NVCC_FLAGS
"-D__STRICT_ANSI__"
)
add_definitions
(
"-DPADDLE_CUDA_BINVER=
\"
90
\"
"
)
elseif
(
${
CUDA_VERSION
}
LESS 11.0
)
# CUDA 10.x
set
(
paddle_known_gpu_archs
${
paddle_known_gpu_archs10
}
)
list
(
APPEND CUDA_NVCC_FLAGS
"-D_MWAITXINTRIN_H_INCLUDED"
)
list
(
APPEND CUDA_NVCC_FLAGS
"-D__STRICT_ANSI__"
)
add_definitions
(
"-DPADDLE_CUDA_BINVER=
\"
100
\"
"
)
endif
()
include_directories
(
${
CUDA_INCLUDE_DIRS
}
)
...
...
cmake/external/boost.cmake
浏览文件 @
d0e3b240
...
...
@@ -23,11 +23,8 @@ set(BOOST_PROJECT "extern_boost")
# checked that the devtools package of CentOS 6 installs boost 1.41.0.
# So we use 1.41.0 here.
set
(
BOOST_VER
"1.41.0"
)
if
((
NOT DEFINED BOOST_TAR
)
OR
(
NOT DEFINED BOOST_URL
))
message
(
STATUS
"use pre defined download url"
)
set
(
BOOST_TAR
"boost_1_41_0"
CACHE STRING
""
FORCE
)
set
(
BOOST_URL
"http://paddlepaddledeps.cdn.bcebos.com/
${
BOOST_TAR
}
.tar.gz"
CACHE STRING
""
FORCE
)
endif
()
set
(
BOOST_TAR
"boost_1_41_0"
CACHE STRING
""
FORCE
)
set
(
BOOST_URL
"http://paddlepaddledeps.cdn.bcebos.com/
${
BOOST_TAR
}
.tar.gz"
CACHE STRING
""
FORCE
)
MESSAGE
(
STATUS
"BOOST_TAR:
${
BOOST_TAR
}
, BOOST_URL:
${
BOOST_URL
}
"
)
...
...
cmake/external/mkldnn.cmake
浏览文件 @
d0e3b240
...
...
@@ -55,7 +55,7 @@ ExternalProject_Add(
${
MKLDNN_PROJECT
}
${
EXTERNAL_PROJECT_LOG_ARGS
}
DEPENDS
${
MKLDNN_DEPENDS
}
GIT_REPOSITORY
"https://github.com/
01org
/mkl-dnn.git"
GIT_REPOSITORY
"https://github.com/
intel
/mkl-dnn.git"
GIT_TAG
"830a10059a018cd2634d94195140cf2d8790a75a"
PREFIX
${
MKLDNN_SOURCES_DIR
}
UPDATE_COMMAND
""
...
...
cmake/external/mklml.cmake
浏览文件 @
d0e3b240
...
...
@@ -16,6 +16,12 @@ IF(NOT ${WITH_MKLML})
return
()
ENDIF
(
NOT
${
WITH_MKLML
}
)
IF
(
APPLE
)
MESSAGE
(
WARNING
"Mac is not supported with MKLML in Paddle yet. Force WITH_MKLML=OFF."
)
SET
(
WITH_MKLML OFF CACHE STRING
"Disable MKLML package in MacOS"
FORCE
)
return
()
ENDIF
()
INCLUDE
(
ExternalProject
)
SET
(
MKLML_DST_DIR
"mklml"
)
SET
(
MKLML_INSTALL_ROOT
"
${
THIRD_PARTY_PATH
}
/install"
)
...
...
@@ -23,32 +29,24 @@ SET(MKLML_INSTALL_DIR ${MKLML_INSTALL_ROOT}/${MKLML_DST_DIR})
SET
(
MKLML_ROOT
${
MKLML_INSTALL_DIR
}
)
SET
(
MKLML_INC_DIR
${
MKLML_ROOT
}
/include
)
SET
(
MKLML_LIB_DIR
${
MKLML_ROOT
}
/lib
)
if
(
WIN32
)
SET
(
CMAKE_INSTALL_RPATH
"
${
CMAKE_INSTALL_RPATH
}
"
"
${
MKLML_ROOT
}
/lib"
)
SET
(
TIME_VERSION
"2019.0.1.20181227"
)
IF
(
WIN32
)
SET
(
MKLML_VER
"mklml_win_
${
TIME_VERSION
}
"
CACHE STRING
""
FORCE
)
SET
(
MKLML_URL
"https://paddlepaddledeps.cdn.bcebos.com/
${
MKLML_VER
}
.zip"
CACHE STRING
""
FORCE
)
SET
(
MKLML_LIB
${
MKLML_LIB_DIR
}
/mklml.lib
)
SET
(
MKLML_IOMP_LIB
${
MKLML_LIB_DIR
}
/libiomp5md.lib
)
SET
(
MKLML_SHARED_LIB
${
MKLML_LIB_DIR
}
/mklml.dll
)
SET
(
MKLML_SHARED_IOMP_LIB
${
MKLML_LIB_DIR
}
/libiomp5md.dll
)
else
()
ELSE
()
SET
(
MKLML_VER
"mklml_lnx_
${
TIME_VERSION
}
"
CACHE STRING
""
FORCE
)
SET
(
MKLML_URL
"http://paddlepaddledeps.cdn.bcebos.com/
${
MKLML_VER
}
.tgz"
CACHE STRING
""
FORCE
)
SET
(
MKLML_LIB
${
MKLML_LIB_DIR
}
/libmklml_intel.so
)
SET
(
MKLML_IOMP_LIB
${
MKLML_LIB_DIR
}
/libiomp5.so
)
SET
(
MKLML_SHARED_LIB
${
MKLML_LIB_DIR
}
/libmklml_intel.so
)
SET
(
MKLML_SHARED_IOMP_LIB
${
MKLML_LIB_DIR
}
/libiomp5.so
)
endif
()
SET
(
CMAKE_INSTALL_RPATH
"
${
CMAKE_INSTALL_RPATH
}
"
"
${
MKLML_ROOT
}
/lib"
)
IF
((
NOT DEFINED MKLML_VER
)
OR
(
NOT DEFINED MKLML_URL
))
MESSAGE
(
STATUS
"use pre defined download url"
)
if
(
WIN32
)
SET
(
MKLML_VER
"mklml_win_2019.0.1.20180928"
CACHE STRING
""
FORCE
)
SET
(
MKLML_URL
"https://paddlepaddledeps.cdn.bcebos.com/
${
MKLML_VER
}
.zip"
CACHE STRING
""
FORCE
)
elseif
(
APPLE
)
SET
(
MKLML_VER
"mklml_mac_2019.0.1.20180928"
CACHE STRING
""
FORCE
)
SET
(
MKLML_URL
"http://paddlepaddledeps.cdn.bcebos.com/
${
MKLML_VER
}
.tgz"
CACHE STRING
""
FORCE
)
else
()
SET
(
MKLML_VER
"mklml_lnx_2019.0.1.20180928"
CACHE STRING
""
FORCE
)
SET
(
MKLML_URL
"http://paddlepaddledeps.cdn.bcebos.com/
${
MKLML_VER
}
.tgz"
CACHE STRING
""
FORCE
)
ENDIF
()
endif
()
ENDIF
()
SET
(
MKLML_PROJECT
"extern_mklml"
)
MESSAGE
(
STATUS
"MKLML_VER:
${
MKLML_VER
}
, MKLML_URL:
${
MKLML_URL
}
"
)
...
...
cmake/generic.cmake
浏览文件 @
d0e3b240
...
...
@@ -117,7 +117,7 @@ function(common_link TARGET_NAME)
endif
()
if
(
WITH_JEMALLOC
)
target_link_libraries
(
${
TARGET_NAME
}
${
JEMALLOC_LIBRARIES
}
)
target_link_libraries
(
${
TARGET_NAME
}
jemalloc::jemalloc
)
endif
()
endfunction
()
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
d0e3b240
...
...
@@ -94,4 +94,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass multi_batch_merge_pass
memory_optimize_pass
)
memory_optimize_pass
lock_free_optimize_pass
)
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
d0e3b240
...
...
@@ -232,3 +232,4 @@ USE_PASS(analysis_var_pass);
USE_PASS
(
sequential_execution_pass
);
USE_PASS
(
all_reduce_deps_pass
);
USE_PASS
(
modify_op_lock_and_record_event_pass
);
USE_PASS
(
lock_free_optimize_pass
);
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
d0e3b240
...
...
@@ -31,6 +31,7 @@ cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
pass_library
(
graph_to_program_pass base
)
pass_library
(
graph_viz_pass base
)
pass_library
(
lock_free_optimize_pass base
)
pass_library
(
fc_fuse_pass inference
)
pass_library
(
attention_lstm_fuse_pass inference
)
pass_library
(
infer_clean_graph_pass inference
)
...
...
paddle/fluid/framework/ir/lock_free_optimize_pass.cc
0 → 100644
浏览文件 @
d0e3b240
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/lock_free_optimize_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
const
char
kSumGradOpName
[]
=
"sum"
;
// TODO(minqiyang): only support sgd at current time, please add
// other optimizers later.
const
char
kOptimizerType
[]
=
"sgd"
;
std
::
unique_ptr
<
ir
::
Graph
>
LockFreeOptimizePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
PADDLE_ENFORCE
(
graph
.
get
());
// We could collect all weights' name from SGD, where
// W1 <- SGD(W0, Grad0)
std
::
unordered_set
<
std
::
string
>
weight_var_set
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
IsOpNamed
(
node
,
kOptimizerType
))
{
auto
&
param_out_vars
=
node
->
Op
()
->
Output
(
"ParamOut"
);
PADDLE_ENFORCE
(
param_out_vars
.
size
()
==
1u
);
weight_var_set
.
insert
(
param_out_vars
[
0
]);
}
}
// find all grad's merge op via weight name, where
// Grad0 <- SUM(Grad1, Grad2, Grad3 ...)
std
::
unordered_set
<
ir
::
Node
*>
grad_sum_op_set
;
for
(
ir
::
Node
*
node
:
graph
->
Nodes
())
{
if
(
IsOpNamed
(
node
,
kSumGradOpName
))
{
for
(
ir
::
Node
*
output
:
node
->
outputs
)
{
// strip the last grad suffix @GRAD
std
::
string
var_name
=
output
->
Name
();
const
std
::
string
suffix
(
kGradVarSuffix
);
if
(
var_name
!=
suffix
&&
var_name
.
size
()
>
suffix
.
size
()
&&
var_name
.
substr
(
var_name
.
size
()
-
suffix
.
size
())
==
suffix
)
{
// if so then strip them off
var_name
=
var_name
.
substr
(
0
,
var_name
.
size
()
-
suffix
.
size
());
if
(
weight_var_set
.
find
(
var_name
)
!=
weight_var_set
.
end
())
{
grad_sum_op_set
.
insert
(
node
);
break
;
}
}
}
}
}
// get the forward op and backward op pairs, where
// out <- forward(X, W)
// Grad1 <- backward(out, X')
// Grad0 <- SUM(Grad1, Grad2, Grad3 ...)
// W0 <- SGD(W1, Grad0)
for
(
ir
::
Node
*
node
:
grad_sum_op_set
)
{
for
(
ir
::
Node
*
merged_grad_var
:
node
->
outputs
)
{
// find the optimizers connected with sum op
if
(
IsVarNameEndsWith
(
merged_grad_var
,
kGradVarSuffix
)
&&
merged_grad_var
->
outputs
.
size
()
==
1u
)
{
ir
::
Node
*
opt_node
=
merged_grad_var
->
outputs
[
0
];
VLOG
(
3
)
<<
"Found opt node "
<<
opt_node
->
Name
();
// find the backward op connected with sum op
for
(
ir
::
Node
*
unmerged_grad_var
:
node
->
inputs
)
{
if
(
IsVarNameContains
(
unmerged_grad_var
,
kGradVarSuffix
)
&&
unmerged_grad_var
->
inputs
.
size
()
==
1u
)
{
ir
::
Node
*
backward_op
=
unmerged_grad_var
->
inputs
[
0
];
VLOG
(
3
)
<<
"Found backward_op "
<<
backward_op
->
Name
();
// find the forward op related to the backward op
ir
::
Node
*
forward_op
=
FindForwardOpViaBackwardOp
(
graph
.
get
(),
backward_op
);
VLOG
(
3
)
<<
"Found forward_op "
<<
forward_op
->
Name
();
PADDLE_ENFORCE
(
forward_op
);
Node
*
new_optimizer_node
=
CreateNewSGDNode
(
graph
.
get
(),
forward_op
,
backward_op
,
node
,
opt_node
);
PADDLE_ENFORCE
(
new_optimizer_node
);
}
}
}
}
}
// Remove the sum_op and its' outputs and connected Optimizers
for
(
Node
*
sum_op
:
grad_sum_op_set
)
{
for
(
Node
*
sum_op_output
:
sum_op
->
outputs
)
{
for
(
Node
*
optimize_op
:
sum_op_output
->
outputs
)
{
if
(
optimize_op
->
NodeType
()
==
Node
::
Type
::
kOperation
&&
optimize_op
->
Name
()
==
kOptimizerType
)
{
VLOG
(
3
)
<<
"remove optimize_op: "
<<
optimize_op
->
Name
()
<<
"_"
<<
optimize_op
->
id
();
graph
->
RemoveNode
(
optimize_op
);
}
}
VLOG
(
3
)
<<
"remove sum_op_output: "
<<
sum_op_output
->
Name
()
<<
"_"
<<
sum_op_output
->
id
();
graph
->
RemoveNode
(
sum_op_output
);
}
VLOG
(
3
)
<<
"remove sum_op: "
<<
sum_op
->
Name
()
<<
"_"
<<
sum_op
->
id
();
graph
->
RemoveNode
(
sum_op
);
}
for
(
auto
*
node
:
graph
->
Nodes
())
{
for
(
Node
*
output_node
:
node
->
outputs
)
{
if
(
output_node
->
Name
()
==
"sgd"
)
{
VLOG
(
3
)
<<
"Node link to SGD: "
<<
node
->
Name
()
<<
"_"
<<
node
->
id
()
<<
" --> "
<<
output_node
->
Name
()
<<
"_"
<<
output_node
->
id
();
for
(
Node
*
input_node
:
node
->
inputs
)
{
VLOG
(
3
)
<<
"SGD Input link: "
<<
input_node
->
Name
()
<<
"_"
<<
input_node
->
id
()
<<
" --> "
<<
node
->
Name
()
<<
"_"
<<
node
->
id
();
}
}
}
}
return
graph
;
}
ir
::
Node
*
LockFreeOptimizePass
::
CreateNewSGDNode
(
ir
::
Graph
*
graph
,
ir
::
Node
*
forward_node
,
ir
::
Node
*
backward_node
,
ir
::
Node
*
grad_sum_node
,
ir
::
Node
*
optimize_node
)
const
{
PADDLE_ENFORCE
(
graph
);
PADDLE_ENFORCE
(
forward_node
);
PADDLE_ENFORCE
(
backward_node
);
PADDLE_ENFORCE
(
grad_sum_node
);
PADDLE_ENFORCE
(
optimize_node
);
// find the grad var node between the grad sum node and backward_node
std
::
vector
<
ir
::
Node
*>
grad_vars
=
FindConnectedNode
(
backward_node
,
grad_sum_node
);
ir
::
Node
*
grad_node
=
nullptr
;
for
(
ir
::
Node
*
node
:
grad_vars
)
{
if
(
!
ir
::
IsControlDepVar
(
*
node
))
{
grad_node
=
node
;
}
}
PADDLE_ENFORCE
(
grad_node
);
// create a new SGD node
OpDesc
*
old_desc
=
optimize_node
->
Op
();
// keep with the same block between new optimizer and the old one
OpDesc
new_desc
(
*
old_desc
,
old_desc
->
Block
());
new_desc
.
SetInput
(
"Param"
,
old_desc
->
Input
(
"Param"
));
new_desc
.
SetInput
(
"LearningRate"
,
old_desc
->
Input
(
"LearningRate"
));
new_desc
.
SetInput
(
"Grad"
,
std
::
vector
<
std
::
string
>
({
grad_node
->
Name
()}));
new_desc
.
SetOutput
(
"ParamOut"
,
old_desc
->
Output
(
"ParamOut"
));
std
::
vector
<
std
::
string
>
op_role_vars
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
new_desc
.
GetAttr
(
framework
::
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
// replace the second op role var, because the grad name was
// changed in new optimizer
op_role_vars
.
pop_back
();
op_role_vars
.
push_back
(
grad_node
->
Name
());
new_desc
.
SetAttr
(
framework
::
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
(),
op_role_vars
);
new_desc
.
SetType
(
kOptimizerType
);
// set backward op's op role var, this will be used to
// set device_id in multi_device_pass
backward_node
->
Op
()
->
SetAttr
(
framework
::
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
(),
op_role_vars
);
// backward_node->Op()->SetAttr(
// framework::OpProtoAndCheckerMaker::OpRoleVarAttrName(), {});
// keep with the same output nodes between new optimizer and the
// old one
Node
*
sgd_node
=
graph
->
CreateOpNode
(
&
new_desc
);
// change all outputs of the optimize_node to the new one
ReplaceAllDownstreamNode
(
optimize_node
,
sgd_node
);
// find connected node between forward node and optimize node
// and replace the optimize node to new sgd node
std
::
vector
<
ir
::
Node
*>
forward_opt_connected_nodes
=
FindConnectedNode
(
forward_node
,
optimize_node
);
for
(
ir
::
Node
*
node
:
forward_opt_connected_nodes
)
{
ReplaceUpstreamNode
(
node
,
optimize_node
,
sgd_node
);
}
// find connected node between backward node and optimize node
// and replace the optimize node to new sgd node
std
::
vector
<
ir
::
Node
*>
backward_opt_connected_nodes
=
FindConnectedNode
(
backward_node
,
optimize_node
);
for
(
ir
::
Node
*
node
:
backward_opt_connected_nodes
)
{
ReplaceUpstreamNode
(
node
,
optimize_node
,
sgd_node
);
}
// SGD must have only one param and LR in
PADDLE_ENFORCE
(
old_desc
->
Input
(
"LearningRate"
).
size
()
==
1u
);
PADDLE_ENFORCE
(
old_desc
->
Input
(
"Param"
).
size
()
==
1u
);
// LR and weight nodes should be copied
for
(
Node
*
upstream_node
:
optimize_node
->
inputs
)
{
if
(
upstream_node
->
Name
()
==
old_desc
->
Input
(
"LearningRate"
)[
0
]
||
upstream_node
->
Name
()
==
old_desc
->
Input
(
"Param"
)[
0
])
{
ReplaceUpstreamNode
(
upstream_node
,
optimize_node
,
sgd_node
);
}
}
VLOG
(
3
)
<<
"Create new opt node"
<<
sgd_node
->
Name
()
<<
"_"
<<
sgd_node
->
id
();
return
sgd_node
;
}
std
::
vector
<
ir
::
Node
*>
LockFreeOptimizePass
::
FindConnectedNode
(
ir
::
Node
*
upstream_node
,
ir
::
Node
*
downstream_node
)
const
{
std
::
vector
<
ir
::
Node
*>
result
;
for
(
ir
::
Node
*
out_node
:
upstream_node
->
outputs
)
{
for
(
ir
::
Node
*
in_node
:
downstream_node
->
inputs
)
{
if
(
in_node
==
out_node
)
{
result
.
push_back
(
in_node
);
}
}
}
return
result
;
}
void
LockFreeOptimizePass
::
ReplaceUpstreamNode
(
ir
::
Node
*
upstream_node
,
ir
::
Node
*
old_optimizer_node
,
ir
::
Node
*
new_optimizer_node
)
const
{
PADDLE_ENFORCE
(
upstream_node
);
PADDLE_ENFORCE
(
old_optimizer_node
);
PADDLE_ENFORCE
(
new_optimizer_node
);
// Remove the old_optimizer_node from upstream_node's outputs vector
auto
&
output_node_vec
=
upstream_node
->
outputs
;
for
(
auto
output_node_iter
=
output_node_vec
.
begin
();
output_node_iter
!=
output_node_vec
.
end
();)
{
if
(
*
output_node_iter
==
old_optimizer_node
)
{
output_node_vec
.
erase
(
output_node_iter
);
break
;
}
else
{
++
output_node_iter
;
}
}
// Add the new_optimizer_node to upstream_node's outputs vector
output_node_vec
.
emplace_back
(
new_optimizer_node
);
new_optimizer_node
->
inputs
.
emplace_back
(
upstream_node
);
}
void
LockFreeOptimizePass
::
ReplaceAllDownstreamNode
(
ir
::
Node
*
old_optimizer_node
,
ir
::
Node
*
new_optimizer_node
)
const
{
PADDLE_ENFORCE
(
old_optimizer_node
);
PADDLE_ENFORCE
(
new_optimizer_node
);
for
(
ir
::
Node
*
downstream_node
:
old_optimizer_node
->
outputs
)
{
// Remove the old_optimizer_node from downstream_node's inputs vector
auto
&
input_node_vec
=
downstream_node
->
inputs
;
for
(
auto
input_node_iter
=
input_node_vec
.
begin
();
input_node_iter
!=
input_node_vec
.
end
();)
{
if
(
*
input_node_iter
==
old_optimizer_node
)
{
input_node_vec
.
erase
(
input_node_iter
);
break
;
}
else
{
++
input_node_iter
;
}
}
// Add the new_optimizer_node to downstream_node's inputs vector
input_node_vec
.
emplace_back
(
new_optimizer_node
);
new_optimizer_node
->
outputs
.
emplace_back
(
downstream_node
);
}
}
ir
::
Node
*
LockFreeOptimizePass
::
FindForwardOpViaBackwardOp
(
ir
::
Graph
*
graph
,
ir
::
Node
*
backward_node
)
const
{
PADDLE_ENFORCE
(
graph
);
PADDLE_ENFORCE
(
backward_node
);
// strip the suffix _grad of backward_node's name
std
::
string
forward_op_name
=
backward_node
->
Name
();
const
std
::
string
suffix
(
"_grad"
);
if
(
forward_op_name
!=
suffix
&&
forward_op_name
.
size
()
>
suffix
.
size
()
&&
forward_op_name
.
substr
(
forward_op_name
.
size
()
-
suffix
.
size
())
==
suffix
)
{
// if so then strip them off
forward_op_name
=
forward_op_name
.
substr
(
0
,
forward_op_name
.
size
()
-
suffix
.
size
());
}
else
{
LOG
(
WARNING
)
<<
"Illegal backward node's name "
<<
backward_node
->
Name
()
<<
" id "
<<
backward_node
->
id
();
return
nullptr
;
}
for
(
ir
::
Node
*
node
:
graph
->
Nodes
())
{
if
(
node
->
Name
()
==
forward_op_name
)
{
if
(
node
->
outputs
.
size
()
==
0u
)
{
// if forward_node has no output, then it has NO grad op
continue
;
}
// check whether all inputs of the backward_op that ends_with @GRAD
// comes from the output of forward_op is the input of the backward_op
bool
is_related_forward_node
=
true
;
for
(
ir
::
Node
*
backward_input
:
backward_node
->
inputs
)
{
if
(
IsVarNameEndsWith
(
backward_input
,
kGradVarSuffix
))
{
bool
meets_correct_output
=
false
;
for
(
ir
::
Node
*
forward_output
:
node
->
outputs
)
{
if
(
forward_output
->
Name
()
+
kGradVarSuffix
==
backward_input
->
Name
())
{
meets_correct_output
=
true
;
break
;
}
}
if
(
!
meets_correct_output
)
{
is_related_forward_node
=
false
;
break
;
}
}
}
if
(
is_related_forward_node
)
{
return
node
;
}
}
}
return
nullptr
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
lock_free_optimize_pass
,
paddle
::
framework
::
ir
::
LockFreeOptimizePass
);
paddle/fluid/framework/ir/lock_free_optimize_pass.h
0 → 100644
浏览文件 @
d0e3b240
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
#define PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
#include <string>
#include <vector>
#include <boost/algorithm/string/predicate.hpp>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
Node
;
/*
* Remove the sum op of all gradients of the backward op.
* And remove the dependecies of the optimizer related to the
* same backward op.
*
* Before this pass:
*
* forward_op1 forward_op2
* | |
* grad_op1 grad_op2
* \ /
* \ /
* sum_op
* |
* sgd_op
*
* After this pass:
* forward_op1 forward_op2
* | |
* grad_op1 grad_op2
* | |
* sgd_op1 sgd_op2
*
* sgd_op1 and sgd_op2 will update the same weight which holds the same
* memory, so we could benefits from the acceleration
*/
class
LockFreeOptimizePass
:
public
Pass
{
public:
virtual
~
LockFreeOptimizePass
()
{}
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
;
private:
// Create a new sgd node via current optimizer node
ir
::
Node
*
CreateNewSGDNode
(
ir
::
Graph
*
graph
,
ir
::
Node
*
forward_node
,
ir
::
Node
*
backward_node
,
ir
::
Node
*
grad_sum_node
,
ir
::
Node
*
optimize_node
)
const
;
// Replace the input weight's optimizers
void
ReplaceUpstreamNode
(
ir
::
Node
*
upstream_node
,
ir
::
Node
*
old_optimizer_node
,
ir
::
Node
*
new_optimizer_node
)
const
;
// Replace the output weight's optimizers
void
ReplaceAllDownstreamNode
(
ir
::
Node
*
old_optimizer_node
,
ir
::
Node
*
new_optimizer_node
)
const
;
// Find all weight variables in graph
bool
FindAllWeightVars
(
ir
::
Graph
*
graph
)
const
;
// Find the forward_op node via the backward_op node
ir
::
Node
*
FindForwardOpViaBackwardOp
(
ir
::
Graph
*
graph
,
ir
::
Node
*
backward_node
)
const
;
std
::
vector
<
ir
::
Node
*>
FindConnectedNode
(
ir
::
Node
*
upstream_node
,
ir
::
Node
*
downstream_node
)
const
;
inline
bool
IsOpNamed
(
ir
::
Node
*
node
,
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE
(
node
);
return
node
->
NodeType
()
==
Node
::
Type
::
kOperation
&&
node
->
Name
()
==
name
;
}
inline
bool
IsVarNamed
(
ir
::
Node
*
node
,
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE
(
node
);
return
node
->
NodeType
()
==
Node
::
Type
::
kVariable
&&
node
->
Name
()
==
name
;
}
inline
bool
IsVarNameEndsWith
(
ir
::
Node
*
node
,
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE
(
node
);
return
node
->
NodeType
()
==
Node
::
Type
::
kVariable
&&
boost
::
algorithm
::
ends_with
(
node
->
Name
(),
name
);
}
inline
bool
IsVarNameContains
(
ir
::
Node
*
node
,
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE
(
node
);
return
node
->
NodeType
()
==
Node
::
Type
::
kVariable
&&
node
->
Name
().
find
(
name
)
!=
std
::
string
::
npos
;
}
inline
bool
IsControlDepFrom
(
ir
::
Node
*
ctrl_dep_node
,
ir
::
Node
*
node
)
const
{
PADDLE_ENFORCE
(
ctrl_dep_node
);
PADDLE_ENFORCE
(
node
);
return
IsControlDepVar
(
*
ctrl_dep_node
)
&&
ctrl_dep_node
->
inputs
.
size
()
>=
1u
&&
ctrl_dep_node
->
inputs
[
0
]
==
node
;
}
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
#endif // PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
paddle/fluid/framework/scope.cc
浏览文件 @
d0e3b240
...
...
@@ -87,11 +87,12 @@ Variable* Scope::Var(const std::string& name) {
}
Variable
*
Scope
::
Var
(
std
::
string
*
name
)
{
auto
new_name
=
string
::
Sprintf
(
"%p.%d"
,
this
,
vars_
.
size
());
SCOPE_VARS_WRITER_LOCK
auto
new_name
=
std
::
to_string
(
reinterpret_cast
<
uintptr_t
>
(
this
))
+
"."
+
std
::
to_string
(
vars_
.
size
());
if
(
name
!=
nullptr
)
{
*
name
=
new_name
;
}
SCOPE_VARS_WRITER_LOCK
return
VarInternal
(
new_name
);
}
...
...
paddle/fluid/framework/var_type_traits.cc
浏览文件 @
d0e3b240
...
...
@@ -105,13 +105,15 @@ struct VarIdToTypeIndexMapHolder {
}
// namespace detail
const
std
::
type_index
&
ToTypeIndex
(
int
var_id
)
{
const
std
::
type_index
&
VarTraitId
ToTypeIndex
(
int
var_id
)
{
return
detail
::
VarIdToTypeIndexMapHolder
::
ToTypeIndex
(
var_id
);
}
const
char
*
ToTypeName
(
int
var_id
)
{
return
ToTypeIndex
(
var_id
).
name
();
}
const
char
*
ToTypeName
(
int
var_id
)
{
return
VarTraitIdToTypeIndex
(
var_id
).
name
();
}
int
T
oType
Id
(
const
std
::
type_index
&
type
)
{
int
T
ypeIndexToVarTrait
Id
(
const
std
::
type_index
&
type
)
{
return
detail
::
VarIdToTypeIndexMapHolder
::
ToTypeId
(
type
);
}
...
...
paddle/fluid/framework/var_type_traits.h
浏览文件 @
d0e3b240
...
...
@@ -66,8 +66,8 @@ namespace paddle {
namespace
framework
{
const
char
*
ToTypeName
(
int
var_id
);
const
std
::
type_index
&
ToTypeIndex
(
int
var_id
);
int
T
oType
Id
(
const
std
::
type_index
&
type
);
const
std
::
type_index
&
VarTraitId
ToTypeIndex
(
int
var_id
);
int
T
ypeIndexToVarTrait
Id
(
const
std
::
type_index
&
type
);
namespace
detail
{
...
...
paddle/fluid/framework/var_type_traits_test.cc
浏览文件 @
d0e3b240
...
...
@@ -45,10 +45,11 @@ struct TypeIndexChecker {
constexpr
auto
kId
=
VarTypeTrait
<
Type
>::
kId
;
std
::
type_index
actual_type
(
typeid
(
Type
));
EXPECT_EQ
(
std
::
string
(
ToTypeName
(
kId
)),
std
::
string
(
actual_type
.
name
()));
EXPECT_EQ
(
ToTypeIndex
(
kId
),
actual_type
);
EXPECT_EQ
(
ToTypeId
(
actual_type
),
kId
);
EXPECT_EQ
(
ToTypeIndex
(
ToTypeId
(
actual_type
)),
actual_type
);
EXPECT_EQ
(
ToTypeId
(
ToTypeIndex
(
kId
)),
kId
);
EXPECT_EQ
(
VarTraitIdToTypeIndex
(
kId
),
actual_type
);
EXPECT_EQ
(
TypeIndexToVarTraitId
(
actual_type
),
kId
);
EXPECT_EQ
(
VarTraitIdToTypeIndex
(
TypeIndexToVarTraitId
(
actual_type
)),
actual_type
);
EXPECT_EQ
(
TypeIndexToVarTraitId
(
VarTraitIdToTypeIndex
(
kId
)),
kId
);
EXPECT_TRUE
(
var_id_set
->
count
(
kId
)
==
0
);
// NOLINT
EXPECT_TRUE
(
type_index_set
->
count
(
actual_type
)
==
0
);
// NOLINT
...
...
paddle/fluid/inference/analysis/analyzer_tester.cc
浏览文件 @
d0e3b240
...
...
@@ -80,8 +80,8 @@ void TestWord2vecPrediction(const std::string& model_path) {
i
++
)
{
LOG
(
INFO
)
<<
"data: "
<<
static_cast
<
float
*>
(
outputs
.
front
().
data
.
data
())[
i
]
<<
" result: "
<<
result
[
i
];
PADDLE_ENFORCE
(
static_cast
<
float
*>
(
outputs
.
front
().
data
.
data
())
[
i
],
result
[
i
]
);
EXPECT_NEAR
(
static_cast
<
float
*>
(
outputs
.
front
().
data
.
data
())[
i
],
result
[
i
],
1e-3
);
}
}
...
...
paddle/fluid/inference/analysis/passes/CMakeLists.txt
浏览文件 @
d0e3b240
...
...
@@ -7,4 +7,5 @@ set(analysis_deps ${analysis_deps}
ir_graph_build_pass
ir_analysis_pass
analysis_passes
subgraph_detector
CACHE INTERNAL
""
)
paddle/fluid/operators/jit/benchmark.cc
浏览文件 @
d0e3b240
...
...
@@ -190,6 +190,26 @@ void BenchGRUKernel() {
}
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
BenchSeqPoolKernel
()
{
std
::
vector
<
jit
::
SeqPoolType
>
pool_types
=
{
jit
::
SeqPoolType
::
kSum
,
jit
::
SeqPoolType
::
kAvg
,
jit
::
SeqPoolType
::
kSqrt
};
for
(
auto
type
:
pool_types
)
{
for
(
int
w
:
TestSizes
())
{
jit
::
seq_pool_attr_t
attr
(
w
,
type
);
for
(
int
h
:
TestSizes
())
{
attr
.
h
=
h
;
std
::
vector
<
T
>
x
(
h
*
w
),
y
(
w
);
RandomVec
<
T
>
(
h
*
w
,
x
.
data
(),
-
2.
f
,
2.
f
);
const
T
*
x_data
=
x
.
data
();
T
*
y_data
=
y
.
data
();
BenchAllImpls
<
KT
,
jit
::
SeqPoolTuples
<
T
>
,
PlaceType
>
(
attr
,
x_data
,
y_data
,
&
attr
);
}
}
}
}
// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
// Options:
...
...
@@ -228,4 +248,7 @@ int main(int argc, char* argv[]) {
BenchGRUKernel
<
jit
::
kGRUH1
,
T
,
PlaceType
>
();
BenchGRUKernel
<
jit
::
kGRUHtPart1
,
T
,
PlaceType
>
();
BenchGRUKernel
<
jit
::
kGRUHtPart2
,
T
,
PlaceType
>
();
// seq pool function
BenchSeqPoolKernel
<
jit
::
kSeqPool
,
T
,
PlaceType
>
();
}
paddle/fluid/operators/jit/gen/CMakeLists.txt
浏览文件 @
d0e3b240
...
...
@@ -26,3 +26,4 @@ USE_JITKERNEL_GEN(kGRUH1)
USE_JITKERNEL_GEN
(
kGRUHtPart1
)
USE_JITKERNEL_GEN
(
kGRUHtPart2
)
USE_JITKERNEL_GEN
(
kNCHW16CMulNC
)
USE_JITKERNEL_GEN
(
kSeqPool
)
paddle/fluid/operators/jit/gen/seqpool.cc
0 → 100644
浏览文件 @
d0e3b240
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/seqpool.h"
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
gen
{
void
SeqPoolJitCode
::
genCode
()
{
constexpr
int
block
=
YMM_FLOAT_BLOCK
;
constexpr
int
max_num_regs
=
8
;
const
int
num_block
=
w_
/
block
;
const
int
num_groups
=
num_block
/
max_num_regs
;
int
rest_num_regs
=
num_block
%
max_num_regs
;
mov
(
reg32_int_h
,
dword
[
param_attr
]);
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
mov
(
reg_tmp
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovups
(
xmm_t
(
1
),
ptr
[
reg_tmp
+
OFFSET_EXP_ONE
]);
mov
(
reg_tmp
,
reinterpret_cast
<
size_t
>
(
fp_h_
));
fild
(
dword
[
param_attr
]);
fstp
(
dword
[
reg_tmp
]);
vmovss
(
xmm_t
(
0
),
ptr
[
reg_tmp
]);
if
(
type_
==
SeqPoolType
::
kSqrt
)
{
vsqrtps
(
xmm_t
(
0
),
xmm_t
(
0
));
}
vdivps
(
xmm_t
(
1
),
xmm_t
(
1
),
xmm_t
(
0
));
vmovss
(
ptr
[
reg_tmp
],
xmm_t
(
1
));
}
const
int
group_len
=
max_num_regs
*
block
*
sizeof
(
float
);
for
(
int
g
=
0
;
g
<
num_groups
;
++
g
)
{
pool_height
<
ymm_t
>
(
g
*
group_len
,
block
,
max_num_regs
);
}
if
(
rest_num_regs
>
0
)
{
pool_height
<
ymm_t
>
(
num_groups
*
group_len
,
block
,
rest_num_regs
);
}
// part of rest_w * height
const
int
rest
=
w_
%
block
;
pool_height_of_rest_width
(
rest
,
(
w_
-
rest
)
*
sizeof
(
float
),
max_num_regs
);
ret
();
}
class
SeqPoolCreator
:
public
JitCodeCreator
<
seq_pool_attr_t
>
{
public:
bool
UseMe
(
const
seq_pool_attr_t
&
attr
)
const
override
{
return
platform
::
MayIUse
(
platform
::
avx
);
}
size_t
CodeSize
(
const
seq_pool_attr_t
&
attr
)
const
override
{
return
96
+
((
attr
.
w
/
YMM_FLOAT_BLOCK
+
4
/* for rest */
)
*
4
/* load, mul and save */
+
256
)
*
8
;
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
seq_pool_attr_t
&
attr
)
const
override
{
PADDLE_ENFORCE_GT
(
attr
.
w
,
0
);
PADDLE_ENFORCE_GT
(
attr
.
h
,
0
);
return
make_unique
<
SeqPoolJitCode
>
(
attr
,
CodeSize
(
attr
));
}
};
}
// namespace gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
namespace
gen
=
paddle
::
operators
::
jit
::
gen
;
REGISTER_JITKERNEL_GEN
(
kSeqPool
,
gen
::
SeqPoolCreator
);
paddle/fluid/operators/jit/gen/seqpool.h
0 → 100644
浏览文件 @
d0e3b240
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
gen
{
class
SeqPoolJitCode
:
public
JitCode
{
public:
explicit
SeqPoolJitCode
(
const
seq_pool_attr_t
&
attr
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
w_
(
attr
.
w
),
type_
(
attr
.
type
)
{
if
(
!
(
type_
==
SeqPoolType
::
kSum
||
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
))
{
LOG
(
FATAL
)
<<
"Only support sum pool yet "
;
}
fp_h_
[
0
]
=
1.
f
;
this
->
genCode
();
}
virtual
const
char
*
name
()
const
{
std
::
string
base
=
"SeqPoolJitCode"
;
if
(
type_
==
SeqPoolType
::
kSum
)
{
base
+=
"_Sum"
;
}
else
if
(
type_
==
SeqPoolType
::
kAvg
)
{
base
+=
"_Avg"
;
}
else
if
(
type_
==
SeqPoolType
::
kSqrt
)
{
base
+=
"_Sqrt"
;
}
base
+=
(
"_W"
+
std
::
to_string
(
w_
));
return
base
.
c_str
();
}
void
genCode
()
override
;
protected:
template
<
typename
JMM
>
void
pool_height
(
int
w_offset
,
int
block
,
int
max_num_regs
)
{
int
offset
=
w_offset
;
for
(
int
i
=
0
;
i
<
max_num_regs
;
++
i
)
{
vmovups
(
JMM
(
i
),
ptr
[
param_src
+
offset
]);
offset
+=
sizeof
(
float
)
*
block
;
}
cmp
(
reg32_int_h
,
1
);
Label
l_next_h
,
l_h_done
;
jle
(
l_h_done
,
T_NEAR
);
mov
(
reg_h_i
,
1
);
mov
(
reg_tmp
,
param_src
);
add
(
reg_tmp
,
w_
*
sizeof
(
float
)
+
w_offset
);
L
(
l_next_h
);
{
mov
(
reg_ptr_src_i
,
reg_tmp
);
for
(
int
i
=
0
;
i
<
max_num_regs
;
++
i
)
{
vmovups
(
JMM
(
i
+
max_num_regs
),
ptr
[
reg_ptr_src_i
]);
// sum anyway
vaddps
(
JMM
(
i
),
JMM
(
i
),
JMM
(
i
+
max_num_regs
));
add
(
reg_ptr_src_i
,
sizeof
(
float
)
*
block
);
}
inc
(
reg_h_i
);
add
(
reg_tmp
,
w_
*
sizeof
(
float
));
cmp
(
reg_h_i
,
reg32_int_h
);
jl
(
l_next_h
,
T_NEAR
);
}
L
(
l_h_done
);
// save right now
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
mov
(
reg_tmp
,
reinterpret_cast
<
size_t
>
(
fp_h_
));
vbroadcastss
(
JMM
(
max_num_regs
),
ptr
[
reg_tmp
]);
}
offset
=
w_offset
;
for
(
int
i
=
0
;
i
<
max_num_regs
;
++
i
)
{
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
vmulps
(
JMM
(
i
),
JMM
(
i
),
JMM
(
max_num_regs
));
}
vmovups
(
ptr
[
param_dst
+
offset
],
JMM
(
i
));
offset
+=
sizeof
(
float
)
*
block
;
}
}
void
pool_height_of_rest_width
(
int
rest
,
int
w_offset
,
int
max_num_regs
)
{
const
int
rest_used_num_regs
=
load_rest
(
rest
,
w_offset
,
0
);
const
bool
has_block4
=
rest
/
4
>
0
;
const
bool
has_block2
=
(
rest
%
4
)
/
2
>
0
;
const
bool
has_block1
=
(
rest
%
2
)
==
1
;
cmp
(
reg32_int_h
,
1
);
Label
l_next_h
,
l_h_done
;
jle
(
l_h_done
,
T_NEAR
);
mov
(
reg_h_i
,
1
);
mov
(
reg_tmp
,
param_src
);
add
(
reg_tmp
,
w_
*
sizeof
(
float
)
+
w_offset
);
L
(
l_next_h
);
{
int
reg_idx
=
0
;
mov
(
reg_ptr_src_i
,
reg_tmp
);
if
(
has_block4
)
{
vmovups
(
xmm_t
(
reg_idx
+
max_num_regs
),
ptr
[
reg_ptr_src_i
]);
add
(
reg_ptr_src_i
,
sizeof
(
float
)
*
4
);
reg_idx
++
;
}
if
(
has_block2
)
{
vmovups
(
xmm_t
(
reg_idx
+
max_num_regs
),
ptr
[
reg_ptr_src_i
]);
add
(
reg_ptr_src_i
,
sizeof
(
float
)
*
2
);
reg_idx
++
;
}
if
(
has_block1
)
{
vmovss
(
xmm_t
(
reg_idx
+
max_num_regs
),
ptr
[
reg_ptr_src_i
]);
reg_idx
++
;
}
PADDLE_ENFORCE_EQ
(
reg_idx
,
rest_used_num_regs
,
"All heights should use same regs"
);
for
(
int
i
=
0
;
i
<
reg_idx
;
++
i
)
{
vaddps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
i
+
max_num_regs
));
}
inc
(
reg_h_i
);
add
(
reg_tmp
,
w_
*
sizeof
(
float
));
cmp
(
reg_h_i
,
reg32_int_h
);
jl
(
l_next_h
,
T_NEAR
);
}
L
(
l_h_done
);
// save right now
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
mov
(
reg_tmp
,
reinterpret_cast
<
size_t
>
(
fp_h_
));
vbroadcastss
(
xmm_t
(
max_num_regs
),
ptr
[
reg_tmp
]);
for
(
int
i
=
0
;
i
<
rest_used_num_regs
;
++
i
)
{
vmulps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
max_num_regs
));
}
}
save_rest
(
rest
,
w_offset
);
}
// return the number of used regs, use start from reg 0
int
load_rest
(
int
rest
,
int
w_offset
,
const
int
num_shift_regs
,
const
int
reg_start
=
0
)
{
const
bool
has_block4
=
rest
/
4
>
0
;
const
bool
has_block2
=
(
rest
%
4
)
/
2
>
0
;
const
bool
has_block1
=
(
rest
%
2
)
==
1
;
int
reg_idx
=
reg_start
;
if
(
has_block4
)
{
vmovups
(
xmm_t
(
reg_idx
+
num_shift_regs
),
ptr
[
param_src
+
w_offset
]);
w_offset
+=
sizeof
(
float
)
*
4
;
reg_idx
++
;
}
if
(
has_block2
)
{
vmovq
(
xmm_t
(
reg_idx
+
num_shift_regs
),
ptr
[
param_src
+
w_offset
]);
w_offset
+=
sizeof
(
float
)
*
2
;
reg_idx
++
;
}
if
(
has_block1
)
{
vmovss
(
xmm_t
(
reg_idx
+
num_shift_regs
),
ptr
[
param_src
+
w_offset
]);
reg_idx
++
;
}
return
reg_idx
;
}
// use reg start from 0
void
save_rest
(
int
rest
,
int
w_offset
,
int
reg_start
=
0
)
{
const
bool
has_block4
=
rest
/
4
>
0
;
const
bool
has_block2
=
(
rest
%
4
)
/
2
>
0
;
const
bool
has_block1
=
(
rest
%
2
)
==
1
;
int
reg_idx
=
reg_start
;
if
(
has_block4
)
{
vmovups
(
ptr
[
param_dst
+
w_offset
],
xmm_t
(
reg_idx
));
w_offset
+=
sizeof
(
float
)
*
4
;
reg_idx
++
;
}
if
(
has_block2
)
{
vmovq
(
ptr
[
param_dst
+
w_offset
],
xmm_t
(
reg_idx
));
w_offset
+=
sizeof
(
float
)
*
2
;
reg_idx
++
;
}
if
(
has_block1
)
{
vmovss
(
ptr
[
param_dst
+
w_offset
],
xmm_t
(
reg_idx
));
}
}
private:
float
ALIGN32_BEG
fp_h_
[
1
]
ALIGN32_END
;
int
w_
;
SeqPoolType
type_
;
reg64_t
param_src
{
abi_param1
};
reg64_t
param_dst
{
abi_param2
};
reg64_t
param_attr
{
abi_param3
};
reg64_t
reg_tmp
{
rax
};
reg32_t
reg32_int_h
{
r8d
};
reg32_t
reg32_fp_h
{
r9d
};
reg64_t
reg_h_i
{
r10
};
reg64_t
reg_ptr_src_i
{
r11
};
};
}
// namespace gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/helper.cc
浏览文件 @
d0e3b240
...
...
@@ -26,6 +26,7 @@ namespace jit {
const
char
*
to_string
(
KernelType
kt
)
{
switch
(
kt
)
{
ONE_CASE
(
kNone
);
ONE_CASE
(
kVMul
);
ONE_CASE
(
kVAdd
);
ONE_CASE
(
kVAddRelu
);
...
...
@@ -45,12 +46,26 @@ const char* to_string(KernelType kt) {
ONE_CASE
(
kCRFDecoding
);
ONE_CASE
(
kLayerNorm
);
ONE_CASE
(
kNCHW16CMulNC
);
ONE_CASE
(
kSeqPool
);
default:
PADDLE_THROW
(
"Not support type: %d, or forget to add it."
,
kt
);
return
"NOT JITKernel"
;
}
return
nullptr
;
}
const
char
*
to_string
(
SeqPoolType
tp
)
{
switch
(
tp
)
{
ONE_CASE
(
kNonePoolType
);
ONE_CASE
(
kSum
);
ONE_CASE
(
kAvg
);
ONE_CASE
(
kSqrt
);
default:
PADDLE_THROW
(
"Not support type: %d, or forget to add it."
,
tp
);
return
"NOT PoolType"
;
}
return
nullptr
;
}
#undef ONE_CASE
KernelType
to_kerneltype
(
const
std
::
string
&
act
)
{
...
...
paddle/fluid/operators/jit/helper.h
浏览文件 @
d0e3b240
...
...
@@ -119,6 +119,7 @@ typename KernelTuples::func_type Get(
}
const
char
*
to_string
(
KernelType
kt
);
const
char
*
to_string
(
SeqPoolType
kt
);
KernelType
to_kerneltype
(
const
std
::
string
&
act
);
...
...
@@ -134,6 +135,11 @@ inline std::ostream& operator<<(std::ostream& os, const gru_attr_t& attr) {
<<
"],act_cand["
<<
to_string
(
attr
.
act_cand
)
<<
"]"
;
return
os
;
}
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
seq_pool_attr_t
&
attr
)
{
os
<<
"height_size["
<<
attr
.
h
<<
"],width_size["
<<
attr
.
w
<<
"],pool_type["
<<
to_string
(
attr
.
type
)
<<
"]"
;
return
os
;
}
}
// namespace jit
}
// namespace operators
...
...
paddle/fluid/operators/jit/kernel_base.h
浏览文件 @
d0e3b240
...
...
@@ -41,8 +41,16 @@ typedef enum {
kCRFDecoding
,
kLayerNorm
,
kNCHW16CMulNC
,
kSeqPool
,
}
KernelType
;
typedef
enum
{
kNonePoolType
=
0
,
kSum
=
1
,
kAvg
,
kSqrt
,
}
SeqPoolType
;
template
<
typename
T
>
struct
XYZNTuples
{
typedef
T
data_type
;
...
...
@@ -112,6 +120,21 @@ struct GRUTuples {
typedef
void
(
*
func_type
)(
gru_t
*
,
const
gru_attr_t
*
);
};
typedef
struct
seq_pool_attr_s
{
int
h
,
w
;
// h should always be the first one
SeqPoolType
type
;
seq_pool_attr_s
()
=
default
;
explicit
seq_pool_attr_s
(
int
width
,
SeqPoolType
pool_type
,
int
height
=
1
)
:
h
(
height
),
w
(
width
),
type
(
pool_type
)
{}
}
seq_pool_attr_t
;
template
<
typename
T
>
struct
SeqPoolTuples
{
typedef
T
data_type
;
typedef
seq_pool_attr_t
attr_type
;
typedef
void
(
*
func_type
)(
const
T
*
,
T
*
,
const
seq_pool_attr_t
*
);
};
template
<
typename
T
>
struct
CRFDecodingTuples
{
typedef
T
data_type
;
...
...
paddle/fluid/operators/jit/kernel_key.cc
浏览文件 @
d0e3b240
...
...
@@ -42,6 +42,13 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
(
static_cast
<
int
>
(
attr
.
act_cand
)
<<
act_type_shift
);
}
template
<
>
size_t
JitCodeKey
<
seq_pool_attr_t
>
(
const
seq_pool_attr_t
&
attr
)
{
size_t
key
=
attr
.
w
;
constexpr
int
pool_type_shift
=
3
;
return
(
key
<<
pool_type_shift
)
+
static_cast
<
int
>
(
attr
.
type
);
}
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
浏览文件 @
d0e3b240
...
...
@@ -9,3 +9,4 @@ USE_JITKERNEL_MORE(kVScal, mkl)
USE_JITKERNEL_MORE
(
kVExp, mkl
)
USE_JITKERNEL_MORE
(
kVSigmoid, mkl
)
USE_JITKERNEL_MORE
(
kVTanh, mkl
)
USE_JITKERNEL_MORE
(
kSeqPool, mkl
)
paddle/fluid/operators/jit/more/mkl/mkl.cc
浏览文件 @
d0e3b240
...
...
@@ -72,6 +72,26 @@ void VExp<double>(const double* x, double* y, int n) {
platform
::
dynload
::
vdExp
(
n
,
x
,
y
);
}
template
<
>
void
VCopy
<
float
>
(
const
float
*
x
,
float
*
y
,
int
n
)
{
platform
::
dynload
::
cblas_scopy
(
n
,
x
,
1
,
y
,
1
);
}
template
<
>
void
VCopy
<
double
>
(
const
double
*
x
,
double
*
y
,
int
n
)
{
platform
::
dynload
::
cblas_dcopy
(
n
,
x
,
1
,
y
,
1
);
}
template
<
>
void
VAXPY
<
float
>
(
float
a
,
const
float
*
x
,
float
*
y
,
int
n
)
{
platform
::
dynload
::
cblas_saxpy
(
n
,
a
,
x
,
1
,
y
,
1
);
}
template
<
>
void
VAXPY
<
double
>
(
double
a
,
const
double
*
x
,
double
*
y
,
int
n
)
{
platform
::
dynload
::
cblas_daxpy
(
n
,
a
,
x
,
1
,
y
,
1
);
}
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template
<
>
bool
VMulKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
...
...
@@ -103,6 +123,16 @@ bool VTanhKernel<float>::UseMe(const int& d) const {
return
d
>
7
;
}
template
<
>
bool
SeqPoolKernel
<
float
>::
UseMe
(
const
seq_pool_attr_t
&
attr
)
const
{
return
true
;
}
template
<
>
bool
SeqPoolKernel
<
double
>::
UseMe
(
const
seq_pool_attr_t
&
attr
)
const
{
return
true
;
}
#define AWALYS_USE_ME_WITH_DOUBLE(func) \
template <> \
bool func##Kernel<double>::UseMe(const int& d) const { \
...
...
@@ -135,5 +165,6 @@ REGISTER_MKL_KERNEL(kVScal, VScal);
REGISTER_MKL_KERNEL
(
kVExp
,
VExp
);
REGISTER_MKL_KERNEL
(
kVSigmoid
,
VSigmoid
);
REGISTER_MKL_KERNEL
(
kVTanh
,
VTanh
);
REGISTER_MKL_KERNEL
(
kSeqPool
,
SeqPool
);
#undef REGISTER_MKL_KERNEL
paddle/fluid/operators/jit/more/mkl/mkl.h
浏览文件 @
d0e3b240
...
...
@@ -14,6 +14,7 @@
#pragma once
#include <cmath>
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
...
...
@@ -35,6 +36,12 @@ void VScal(const T* a, const T* x, T* y, int n);
template
<
typename
T
>
void
VExp
(
const
T
*
x
,
T
*
y
,
int
n
);
template
<
typename
T
>
void
VCopy
(
const
T
*
x
,
T
*
y
,
int
n
);
template
<
typename
T
>
void
VAXPY
(
T
a
,
const
T
*
x
,
T
*
y
,
int
n
);
template
<
typename
T
>
void
VSigmoid
(
const
T
*
x
,
T
*
y
,
int
n
)
{
const
T
min
=
SIGMOID_THRESHOLD_MIN
;
...
...
@@ -60,6 +67,23 @@ void VTanh(const T* x, T* y, int n) {
}
}
template
<
typename
T
>
void
SeqPool
(
const
T
*
x
,
T
*
y
,
const
seq_pool_attr_t
*
attr
)
{
VCopy
<
T
>
(
x
,
y
,
attr
->
w
);
for
(
int
h
=
1
;
h
!=
attr
->
h
;
++
h
)
{
VAXPY
<
T
>
(
static_cast
<
T
>
(
1
),
x
+
h
*
attr
->
w
,
y
,
attr
->
w
);
}
if
(
attr
->
type
==
SeqPoolType
::
kAvg
||
attr
->
type
==
SeqPoolType
::
kSqrt
)
{
T
scalar
=
static_cast
<
T
>
(
1
);
if
(
attr
->
type
==
SeqPoolType
::
kAvg
)
{
scalar
=
scalar
/
static_cast
<
T
>
(
attr
->
h
);
}
else
{
scalar
=
scalar
/
std
::
sqrt
(
static_cast
<
T
>
(
attr
->
h
));
}
VScal
<
T
>
(
&
scalar
,
y
,
y
,
attr
->
w
);
}
}
#define DECLARE_MKL_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public KernelMore<tuples<T>> { \
...
...
@@ -81,6 +105,8 @@ DECLARE_MKL_KERNEL(VExp, XYNTuples);
DECLARE_MKL_KERNEL
(
VSigmoid
,
XYNTuples
);
DECLARE_MKL_KERNEL
(
VTanh
,
XYNTuples
);
DECLARE_MKL_KERNEL
(
SeqPool
,
SeqPoolTuples
);
#undef DECLARE_MKL_KERNEL
}
// namespace mkl
...
...
paddle/fluid/operators/jit/refer/CMakeLists.txt
浏览文件 @
d0e3b240
...
...
@@ -26,3 +26,4 @@ USE_JITKERNEL_REFER(kGRUHtPart2)
USE_JITKERNEL_REFER
(
kCRFDecoding
)
USE_JITKERNEL_REFER
(
kLayerNorm
)
USE_JITKERNEL_REFER
(
kNCHW16CMulNC
)
USE_JITKERNEL_REFER
(
kSeqPool
)
paddle/fluid/operators/jit/refer/refer.cc
浏览文件 @
d0e3b240
...
...
@@ -47,4 +47,6 @@ REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm);
REGISTER_REFER_KERNEL
(
kNCHW16CMulNC
,
NCHW16CMulNC
);
REGISTER_REFER_KERNEL
(
kSeqPool
,
SeqPool
);
#undef REGISTER_REFER_KERNEL
paddle/fluid/operators/jit/refer/refer.h
浏览文件 @
d0e3b240
...
...
@@ -332,6 +332,28 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) {
}
}
template
<
typename
T
>
void
SeqPool
(
const
T
*
x
,
T
*
y
,
const
seq_pool_attr_t
*
attr
)
{
for
(
int
w
=
0
;
w
<
attr
->
w
;
++
w
)
{
const
T
*
src
=
x
+
w
;
T
*
dst
=
y
+
w
;
*
dst
=
static_cast
<
T
>
(
0
);
for
(
int
h
=
0
;
h
<
attr
->
h
;
++
h
)
{
*
dst
=
*
dst
+
*
src
;
src
+=
attr
->
w
;
}
}
if
(
attr
->
type
==
SeqPoolType
::
kAvg
||
attr
->
type
==
SeqPoolType
::
kSqrt
)
{
T
scalar
=
static_cast
<
T
>
(
1
);
if
(
attr
->
type
==
SeqPoolType
::
kAvg
)
{
scalar
=
scalar
/
static_cast
<
T
>
(
attr
->
h
);
}
else
{
scalar
=
scalar
/
std
::
sqrt
(
static_cast
<
T
>
(
attr
->
h
));
}
VScal
<
T
>
(
&
scalar
,
y
,
y
,
attr
->
w
);
}
}
#define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
...
...
@@ -370,6 +392,8 @@ DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples);
DECLARE_REFER_KERNEL
(
NCHW16CMulNC
,
NCHW16CMulNCTuples
);
DECLARE_REFER_KERNEL
(
SeqPool
,
SeqPoolTuples
);
#undef DECLARE_REFER_KERNEL
}
// namespace refer
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
d0e3b240
...
...
@@ -211,6 +211,24 @@ struct TestFuncWithRefer<jit::GRUTuples<T>, std::vector<T>, std::vector<T>,
}
};
template
<
typename
T
>
struct
TestFuncWithRefer
<
jit
::
SeqPoolTuples
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>>
{
void
operator
()(
const
typename
jit
::
SeqPoolTuples
<
T
>::
func_type
tgt
,
const
std
::
vector
<
T
>&
x
,
const
std
::
vector
<
T
>&
yref
,
const
typename
jit
::
SeqPoolTuples
<
T
>::
attr_type
&
attr
)
{
EXPECT_TRUE
(
tgt
!=
nullptr
);
EXPECT_EQ
(
x
.
size
()
%
yref
.
size
(),
0
);
int
w
=
yref
.
size
();
std
::
vector
<
T
>
y
(
w
);
const
T
*
x_data
=
x
.
data
();
const
T
*
yref_data
=
yref
.
data
();
T
*
y_data
=
y
.
data
();
tgt
(
x_data
,
y_data
,
&
attr
);
ExpectEQ
<
T
>
(
y_data
,
yref_data
,
w
);
}
};
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
KernelTuples
,
typename
PlaceType
,
typename
...
Args
>
void
TestAllImpls
(
const
typename
KernelTuples
::
attr_type
&
attr
,
Args
...
args
)
{
...
...
@@ -415,6 +433,31 @@ void TestGRUKernel() {
}
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
TestSeqPoolKernel
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
std
::
vector
<
jit
::
SeqPoolType
>
pool_types
=
{
jit
::
SeqPoolType
::
kSum
,
jit
::
SeqPoolType
::
kAvg
,
jit
::
SeqPoolType
::
kSqrt
};
for
(
auto
type
:
pool_types
)
{
for
(
int
w
:
TestSizes
())
{
jit
::
seq_pool_attr_t
attr
(
w
,
type
);
for
(
int
h
:
TestSizes
())
{
attr
.
h
=
h
;
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
SeqPoolTuples
<
T
>>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
x
(
h
*
w
),
yref
(
w
);
RandomVec
<
T
>
(
h
*
w
,
x
.
data
(),
-
2.
f
,
2.
f
);
const
T
*
x_data
=
x
.
data
();
T
*
yref_data
=
yref
.
data
();
ref
(
x_data
,
yref_data
,
&
attr
);
VLOG
(
10
)
<<
attr
;
TestAllImpls
<
KT
,
jit
::
SeqPoolTuples
<
T
>
,
PlaceType
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>>
(
attr
,
x
,
yref
,
attr
);
}
}
}
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
TestNCHW16CMulNCKernel
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
...
...
@@ -569,6 +612,12 @@ TEST(JITKernel, kGRUHtPart2) {
TestGRUKernel
<
jit
::
kGRUHtPart2
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
kSeqPool
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestSeqPoolKernel
<
jit
::
kSeqPool
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestSeqPoolKernel
<
jit
::
kSeqPool
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
kNCHW16CMulNC
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestNCHW16CMulNCKernel
<
jit
::
kNCHW16CMulNC
,
float
,
...
...
paddle/fluid/operators/math/CMakeLists.txt
浏览文件 @
d0e3b240
...
...
@@ -51,7 +51,7 @@ math_library(pooling)
math_library
(
selected_rows_functor DEPS selected_rows math_function blas
)
math_library
(
sequence2batch
)
math_library
(
sequence_padding
)
math_library
(
sequence_pooling DEPS math_function
)
math_library
(
sequence_pooling DEPS math_function
jit_kernel_helper
)
math_library
(
sequence_scale
)
math_library
(
softmax DEPS math_function
)
...
...
paddle/fluid/operators/math/blas_impl.cu.h
浏览文件 @
d0e3b240
...
...
@@ -62,27 +62,19 @@ struct CUBlas<float> {
cudaDataType_t
Atype
,
int
lda
,
const
void
*
B
,
cudaDataType_t
Btype
,
int
ldb
,
const
float
*
beta
,
void
*
C
,
cudaDataType_t
Ctype
,
int
ldc
)
{
// Because the gcc 4.8 doesn't expand template parameter pack that
// appears in a lambda-expression, I can not use template parameter pack
// here.
auto
cublas_call
=
[
&
]()
{
// Because the gcc 4.8 doesn't expand template parameter pack that
// appears in a lambda-expression, I can not use template parameter pack
// here.
#if CUDA_VERSION >= 8000
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
platform
::
TensorCoreAvailable
()
?
"True"
:
"False"
);
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
dev_ctx
->
tensor_core_available
()
?
"True"
:
"False"
);
dev_ctx
->
TensorCoreCublasCallIfAvailable
([
&
](
cublasHandle_t
handle
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemmEx
(
dev_ctx
->
cublas_handle
(),
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
lda
,
B
,
Btype
,
ldb
,
beta
,
C
,
Ctype
,
ldc
));
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
lda
,
B
,
Btype
,
ldb
,
beta
,
C
,
Ctype
,
ldc
));
});
#else
PADDLE_THROW
(
"cublasSgemmEx is supported on cuda >= 8.0"
);
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx
->
CublasCall
(
cublas_call
,
CUBLAS_TENSOR_OP_MATH
);
#else
cublas_call
();
PADDLE_THROW
(
"cublasSgemmEx is supported on cuda >= 8.0"
);
#endif
}
};
...
...
@@ -170,32 +162,24 @@ struct CUBlas<platform::float16> {
cudaDataType_t
Btype
,
int
ldb
,
const
void
*
beta
,
void
*
C
,
cudaDataType_t
Ctype
,
int
ldc
,
cudaDataType_t
computeType
)
{
auto
cublas_call
=
[
&
]()
{
#if CUDA_VERSION >= 8000
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
#if CUDA_VERSION >= 9000
bool
use_tensor_op_math
=
platform
::
TensorCoreA
vailable
();
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
bool
use_tensor_op_math
=
dev_ctx
->
tensor_core_a
vailable
();
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
#endif // CUDA_VERSION >= 9000
dev_ctx
->
TensorCoreCublasCallIfAvailable
([
&
](
cublasHandle_t
handle
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmEx
(
dev_ctx
->
cublas_handle
(),
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
lda
,
B
,
Btype
,
ldb
,
beta
,
C
,
Ctype
,
ldc
,
computeType
,
algo
));
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
lda
,
B
,
Btype
,
ldb
,
beta
,
C
,
Ctype
,
ldc
,
computeType
,
algo
));
});
#else
PADDLE_THROW
(
"cublasGemmEx is supported on cuda >= 8.0"
);
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx
->
CublasCall
(
cublas_call
,
CUBLAS_TENSOR_OP_MATH
);
#else
cublas_call
();
PADDLE_THROW
(
"cublasGemmEx is supported on cuda >= 8.0"
);
#endif
}
};
...
...
@@ -223,9 +207,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CUDA_R_32F
,
N
);
}
else
{
#endif // CUDA_VERSION >= 8000
CUBlas
<
T
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
);
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
);
});
#if CUDA_VERSION >= 8000
}
...
...
@@ -266,9 +251,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
CUDA_R_16F
,
lda
,
&
h_beta
,
C
,
CUDA_R_16F
,
N
,
CUDA_R_32F
);
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
CUBlas
<
platform
::
float16
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
&
h_beta
,
h_C
,
N
);
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
platform
::
float16
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
&
h_beta
,
h_C
,
N
);
});
#endif // CUDA_VERSION >= 8000
}
...
...
@@ -292,8 +280,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
}
else
{
#endif // CUDA_VERSION >= 8000
CUBlas
<
T
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
});
#if CUDA_VERSION >= 8000
}
...
...
@@ -311,16 +301,19 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
cublasOperation_t
cuTransA
=
transA
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
cuTransB
=
transB
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
CUBlas
<
platform
::
float16
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
platform
::
float16
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
});
}
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CUDADeviceContext
>::
AXPY
(
int
n
,
T
alpha
,
const
T
*
x
,
T
*
y
)
const
{
CUBlas
<
T
>::
AXPY
(
context_
.
cublas_handle
(),
n
,
&
alpha
,
x
,
1
,
y
,
1
);
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
AXPY
(
handle
,
n
,
&
alpha
,
x
,
1
,
y
,
1
);
});
}
template
<
>
...
...
@@ -330,8 +323,9 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
T
beta
,
T
*
C
)
const
{
cublasOperation_t
cuTransA
=
!
trans_a
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
CUBlas
<
T
>::
GEMV
(
context_
.
cublas_handle
(),
cuTransA
,
N
,
M
,
&
alpha
,
A
,
N
,
B
,
1
,
&
beta
,
C
,
1
);
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
GEMV
(
handle
,
cuTransA
,
N
,
M
,
&
alpha
,
A
,
N
,
B
,
1
,
&
beta
,
C
,
1
);
});
}
template
<
>
...
...
@@ -353,28 +347,28 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
#if CUDA_VERSION >= 9010
if
(
FLAGS_enable_cublas_tensor_op_math
&&
std
::
is_same
<
T
,
float
>::
value
)
{
auto
cublas_call
=
[
&
]()
{
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
bool
use_tensor_op_math
=
platform
::
TensorCoreAvailable
();
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
bool
use_tensor_op_math
=
context_
.
tensor_core_available
()
;
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
context_
.
TensorCoreCublasCallIfAvailable
([
&
](
cublasHandle_t
handle
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmStridedBatchedEx
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
CUDA_R_32F
,
ldb
,
strideB
,
A
,
CUDA_R_32F
,
lda
,
strideA
,
&
beta
,
C
,
CUDA_R_32F
,
ldc
,
strideC
,
batchCount
,
CUDA_R_32F
,
algo
));
};
auto
&
dev_ctx
=
const_cast
<
platform
::
CUDADeviceContext
&>
(
context_
);
dev_ctx
.
CublasCall
(
cublas_call
,
CUBLAS_TENSOR_OP_MATH
);
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
CUDA_R_32F
,
ldb
,
strideB
,
A
,
CUDA_R_32F
,
lda
,
strideA
,
&
beta
,
C
,
CUDA_R_32F
,
ldc
,
strideC
,
batchCount
,
CUDA_R_32F
,
algo
));
});
}
else
{
#endif // CUDA_VERSION >= 9010
CUBlas
<
T
>::
GEMM_STRIDED_BATCH
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
strideB
,
A
,
lda
,
strideA
,
&
beta
,
C
,
ldc
,
strideC
,
batchCount
);
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
GEMM_STRIDED_BATCH
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
strideB
,
A
,
lda
,
strideA
,
&
beta
,
C
,
ldc
,
strideC
,
batchCount
);
});
#if CUDA_VERSION >= 9010
}
...
...
paddle/fluid/operators/math/sequence_pooling.cc
浏览文件 @
d0e3b240
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence_pooling.h"
...
...
@@ -239,15 +240,33 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
last_pool
(
context
,
input
,
output
);
return
;
}
if
(
pooltype
==
"FIRST"
)
{
math
::
FirstSeqPoolFunctor
<
T
>
first_pool
;
first_pool
(
context
,
input
,
output
);
return
;
}
auto
lod
=
input
.
lod
()[
0
];
if
(
pooltype
==
"SUM"
)
{
auto
place
=
context
.
GetPlace
();
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
place
));
const
T
*
src
=
input
.
data
<
T
>
();
T
*
dst
=
output
->
mutable_data
<
T
>
(
place
);
jit
::
seq_pool_attr_t
attr
(
static_cast
<
int
>
(
input
.
numel
()
/
input
.
dims
()[
0
]),
jit
::
SeqPoolType
::
kSum
);
auto
seqpool
=
jit
::
Get
<
jit
::
kSeqPool
,
jit
::
SeqPoolTuples
<
T
>
,
platform
::
CPUPlace
>
(
attr
);
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
lod
.
size
())
-
1
;
++
i
)
{
attr
.
h
=
static_cast
<
int
>
(
lod
[
i
+
1
]
-
lod
[
i
]);
seqpool
(
src
,
dst
,
&
attr
);
dst
+=
attr
.
w
;
src
+=
attr
.
h
*
attr
.
w
;
}
return
;
}
auto
&
place
=
*
context
.
eigen_device
();
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
T
>
(
context
);
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
lod
.
size
())
-
1
;
++
i
)
{
Tensor
in_t
=
input
.
Slice
(
static_cast
<
int
>
(
lod
[
i
]),
static_cast
<
int
>
(
lod
[
i
+
1
]));
...
...
@@ -258,15 +277,6 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
auto
out_e
=
EigenVector
<
T
>::
Flatten
(
out_t
);
if
(
pooltype
==
"AVERAGE"
)
{
out_e
.
device
(
place
)
=
in_e
.
mean
(
Eigen
::
array
<
int
,
1
>
({{
0
}}));
}
else
if
(
pooltype
==
"SUM"
)
{
if
(
h
>
0
)
{
const
T
*
in_data
=
in_t
.
data
<
T
>
();
T
*
out_data
=
out_t
.
mutable_data
<
T
>
(
context
.
GetPlace
());
blas
.
VCOPY
(
w
,
in_data
,
out_data
);
for
(
int64_t
r
=
1
;
r
!=
h
;
++
r
)
{
blas
.
AXPY
(
w
,
1.
,
in_data
+
r
*
w
,
out_data
);
}
}
}
else
if
(
pooltype
==
"SQRT"
)
{
out_e
.
device
(
place
)
=
in_e
.
sum
(
Eigen
::
array
<
int
,
1
>
({{
0
}}))
/
std
::
sqrt
(
static_cast
<
T
>
(
h
));
...
...
paddle/fluid/operators/ngraph/ops/binary_unnary_op.h
浏览文件 @
d0e3b240
...
...
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once
#include <string>
...
...
@@ -48,4 +47,3 @@ static void BuildUnaryNode(
}
// namespace ngraphs
}
// namespace operators
}
// namespace paddle
#endif
paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h
浏览文件 @
d0e3b240
...
...
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once
#include <string>
...
...
@@ -58,4 +57,3 @@ std::shared_ptr<ngraph::Node> ElementwiseScalar(
}
// namespace ngraphs
}
// namespace operators
}
// namespace paddle
#endif
paddle/fluid/operators/ngraph/ops/fill_constant_op.h
浏览文件 @
d0e3b240
...
...
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once
#include <string>
...
...
@@ -58,4 +57,3 @@ void BuildFillConstantNode(
}
// namespace ngraphs
}
// namespace operators
}
// namespace paddle
#endif
paddle/fluid/operators/ngraph/ops/mean_op.h
浏览文件 @
d0e3b240
...
...
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once
#include <functional>
...
...
@@ -65,4 +64,3 @@ void BuildMeanGradNode(
}
// namespace ngraphs
}
// namespace operators
}
// namespace paddle
#endif
paddle/fluid/operators/ngraph/ops/mul_op.h
浏览文件 @
d0e3b240
...
...
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once
#include <string>
...
...
@@ -131,4 +130,3 @@ static void BuildMulGradNode(
}
// namespace ngraphs
}
// namespace operators
}
// namespace paddle
#endif
paddle/fluid/operators/ngraph/ops/scale_op.h
浏览文件 @
d0e3b240
...
...
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once
#include <string>
...
...
@@ -38,4 +37,3 @@ void BuildScaleNode(
}
// namespace ngraphs
}
// namespace operators
}
// namespace paddle
#endif
paddle/fluid/operators/ngraph/ops/top_k_op.h
浏览文件 @
d0e3b240
...
...
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once
#include <string>
...
...
@@ -48,4 +47,3 @@ void BuildTopKNode(
}
// namespace ngraphs
}
// namespace operators
}
// namespace paddle
#endif
paddle/fluid/platform/cuda_helper.h
0 → 100644
浏览文件 @
d0e3b240
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/macros.h"
#if CUDA_VERSION < 9000
enum
cublasMath_t
{
CUBLAS_DEFAULT_MATH
=
0
};
#endif
namespace
paddle
{
namespace
platform
{
class
CublasHandleHolder
{
public:
CublasHandleHolder
(
cudaStream_t
stream
,
cublasMath_t
math_type
)
{
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
handle_
,
stream
));
#if CUDA_VERSION >= 9000
if
(
math_type
==
CUBLAS_TENSOR_OP_MATH
)
{
PADDLE_ENFORCE
(
dynload
::
cublasSetMathMode
(
handle_
,
CUBLAS_TENSOR_OP_MATH
));
}
#endif
}
~
CublasHandleHolder
()
{
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
handle_
));
}
template
<
typename
Callback
>
inline
void
Call
(
Callback
&&
callback
)
const
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mtx_
);
callback
(
handle_
);
}
private:
DISABLE_COPY_AND_ASSIGN
(
CublasHandleHolder
);
cublasHandle_t
handle_
;
mutable
std
::
mutex
mtx_
;
};
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/device_context.cc
浏览文件 @
d0e3b240
...
...
@@ -245,8 +245,15 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
eigen_stream_
.
reset
(
new
EigenCudaStreamDevice
());
eigen_stream_
->
Reinitialize
(
&
stream_
,
place
);
eigen_device_
.
reset
(
new
Eigen
::
GpuDevice
(
eigen_stream_
.
get
()));
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
cublas_handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
cublas_handle_
,
stream_
));
cublas_handle_
.
reset
(
new
CublasHandleHolder
(
stream_
,
CUBLAS_DEFAULT_MATH
));
if
(
TensorCoreAvailable
())
{
#if CUDA_VERSION >= 9000
cublas_tensor_core_handle_
.
reset
(
new
CublasHandleHolder
(
stream_
,
CUBLAS_TENSOR_OP_MATH
));
#endif
}
if
(
dynload
::
HasCUDNN
())
{
cudnn_holder_
.
reset
(
new
CudnnHolder
(
&
stream_
,
place
));
}
...
...
@@ -306,7 +313,8 @@ CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId
(
place_
.
device
);
Wait
();
WaitStreamCallback
();
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
cublas_handle_
));
cublas_handle_
.
reset
();
cublas_tensor_core_handle_
.
reset
();
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
...
...
@@ -335,8 +343,8 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return
eigen_device_
.
get
();
}
cublasHandle_t
CUDADeviceContext
::
cublas_hand
le
()
const
{
return
cublas_
handle_
;
bool
CUDADeviceContext
::
tensor_core_availab
le
()
const
{
return
cublas_
tensor_core_handle_
!=
nullptr
;
}
cudnnHandle_t
CUDADeviceContext
::
cudnn_handle
()
const
{
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
d0e3b240
...
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/temporary_allocator.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_helper.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/gpu_info.h"
...
...
@@ -209,39 +210,6 @@ class CudnnWorkspaceHandle {
std
::
unique_ptr
<
std
::
lock_guard
<
std
::
mutex
>>
guard_
;
};
#if CUDA_VERSION >= 9000
class
ScopedCublasMathMode
{
public:
ScopedCublasMathMode
(
cublasHandle_t
handle
,
cublasMath_t
new_math_mode
)
:
handle_
(
handle
)
{
need_reset
=
false
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGetMathMode
(
handle_
,
&
old_math_mode_
),
"Failed to get old cublas math mode"
);
if
(
old_math_mode_
!=
new_math_mode
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
handle_
,
new_math_mode
),
"Failed to set old cublas math mode"
);
need_reset
=
true
;
}
}
~
ScopedCublasMathMode
()
{
if
(
need_reset
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
handle_
,
old_math_mode_
),
"Failed to set old cublas math mode"
);
}
}
private:
cublasHandle_t
handle_
;
cublasMath_t
old_math_mode_
;
bool
need_reset
;
};
#endif
class
CUDADeviceContext
:
public
DeviceContext
{
public:
explicit
CUDADeviceContext
(
CUDAPlace
place
);
...
...
@@ -262,8 +230,25 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return eigen device in the device context. */
Eigen
::
GpuDevice
*
eigen_device
()
const
;
/*! \brief Return cublas handle in the device context. */
cublasHandle_t
cublas_handle
()
const
;
/*! \brief Call cublas function safely. */
template
<
typename
Callback
>
inline
void
CublasCall
(
Callback
&&
callback
)
const
{
cublas_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
/*! \brief Check whether tensor core is supported */
bool
tensor_core_available
()
const
;
/*! \brief Call cublas function with Tensor Core safely. If
Tensor Core is not available, use DEFAULT_MATH instead. */
template
<
typename
Callback
>
inline
void
TensorCoreCublasCallIfAvailable
(
Callback
&&
callback
)
const
{
if
(
cublas_tensor_core_handle_
)
{
cublas_tensor_core_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
else
{
cublas_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
}
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t
cudnn_handle
()
const
;
...
...
@@ -282,7 +267,6 @@ class CUDADeviceContext : public DeviceContext {
template
<
typename
Callback
>
void
RecordEvent
(
cudaEvent_t
ev
,
Callback
callback
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mtx_
);
callback
();
PADDLE_ENFORCE
(
cudaEventRecord
(
ev
,
stream_
));
}
...
...
@@ -294,18 +278,6 @@ class CUDADeviceContext : public DeviceContext {
void
WaitStreamCallback
()
const
{
callback_manager_
->
Wait
();
}
#if CUDA_VERSION >= 9000
/*! \brief CublasCall may need to change cublas's config,
* but the cublas may be hold by multi-thread, so we should
* add lock here. */
template
<
typename
Callback
>
void
CublasCall
(
Callback
callback
,
cublasMath_t
new_math
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
cublas_mtx_
);
ScopedCublasMathMode
scoped_cublas_math
(
cublas_handle_
,
new_math
);
callback
();
}
#endif
private:
CUDAPlace
place_
;
...
...
@@ -313,7 +285,9 @@ class CUDADeviceContext : public DeviceContext {
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
std
::
unique_ptr
<
CudnnHolder
>
cudnn_holder_
;
cudaStream_t
stream_
;
cublasHandle_t
cublas_handle_
;
std
::
unique_ptr
<
CublasHandleHolder
>
cublas_handle_
;
std
::
unique_ptr
<
CublasHandleHolder
>
cublas_tensor_core_handle_
;
int
compute_capability_
;
int
runtime_version_
;
...
...
@@ -321,12 +295,10 @@ class CUDADeviceContext : public DeviceContext {
int
multi_process_
;
int
max_threads_per_mp_
;
mutable
std
::
mutex
mtx_
;
// StreamCallbackManager is thread-safe
std
::
unique_ptr
<
StreamCallbackManager
>
callback_manager_
;
mutable
std
::
mutex
cublas_mtx_
;
DISABLE_COPY_AND_ASSIGN
(
CUDADeviceContext
)
;
};
template
<
>
...
...
paddle/fluid/platform/device_context_test.cu
浏览文件 @
d0e3b240
...
...
@@ -43,9 +43,6 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE
(
nullptr
,
gpu_device
);
cudnnHandle_t
cudnn_handle
=
device_context
->
cudnn_handle
();
ASSERT_NE
(
nullptr
,
cudnn_handle
);
cublasHandle_t
cublas_handle
=
device_context
->
cublas_handle
();
ASSERT_NE
(
nullptr
,
cublas_handle
);
ASSERT_NE
(
nullptr
,
device_context
->
stream
());
delete
device_context
;
}
}
...
...
python/paddle/fluid/__init__.py
浏览文件 @
d0e3b240
...
...
@@ -155,7 +155,7 @@ def __bootstrap__():
'fraction_of_gpu_memory_to_use'
,
'cudnn_deterministic'
,
'enable_cublas_tensor_op_math'
,
'conv_workspace_size_limit'
,
'cudnn_exhaustive_search'
,
'memory_optimize_debug'
,
'selected_gpus'
,
'
cudnn_exhaustive_search_times'
,
'
sync_nccl_allreduce'
'sync_nccl_allreduce'
]
core
.
init_gflags
([
sys
.
argv
[
0
]]
+
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录