Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
3c3df1fd
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3c3df1fd
编写于
10月 08, 2018
作者:
S
shippingwang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into develop
上级
00b11c27
25262ed0
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
101 addition
and
38 deletion
+101
-38
CMakeLists.txt
CMakeLists.txt
+1
-0
cmake/cuda.cmake
cmake/cuda.cmake
+4
-1
cmake/external/eigen.cmake
cmake/external/eigen.cmake
+8
-0
cmake/flags.cmake
cmake/flags.cmake
+2
-0
paddle/fluid/framework/rw_lock.h
paddle/fluid/framework/rw_lock.h
+1
-0
paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc
paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc
+3
-0
paddle/fluid/operators/top_k_op.cu
paddle/fluid/operators/top_k_op.cu
+64
-27
paddle/fluid/operators/while_op.cc
paddle/fluid/operators/while_op.cc
+6
-4
paddle/fluid/platform/gpu_info.cc
paddle/fluid/platform/gpu_info.cc
+5
-2
paddle/scripts/paddle_build.sh
paddle/scripts/paddle_build.sh
+3
-4
python/paddle/fluid/layers/control_flow.py
python/paddle/fluid/layers/control_flow.py
+4
-0
未找到文件。
CMakeLists.txt
浏览文件 @
3c3df1fd
...
@@ -72,6 +72,7 @@ option(WITH_INFERENCE "Compile fluid inference library" ON)
...
@@ -72,6 +72,7 @@ option(WITH_INFERENCE "Compile fluid inference library" ON)
option
(
WITH_INFERENCE_API_TEST
"Test fluid inference high-level api interface"
OFF
)
option
(
WITH_INFERENCE_API_TEST
"Test fluid inference high-level api interface"
OFF
)
option
(
WITH_SYSTEM_BLAS
"Use system blas library"
OFF
)
option
(
WITH_SYSTEM_BLAS
"Use system blas library"
OFF
)
option
(
PY_VERSION
"Compile PaddlePaddle with python3 support"
${
PY_VERSION
}
)
option
(
PY_VERSION
"Compile PaddlePaddle with python3 support"
${
PY_VERSION
}
)
option
(
WITH_FAST_MATH
"Make use of fast math library"
OFF
)
# PY_VERSION
# PY_VERSION
if
(
NOT PY_VERSION
)
if
(
NOT PY_VERSION
)
...
...
cmake/cuda.cmake
浏览文件 @
3c3df1fd
...
@@ -175,7 +175,10 @@ list(APPEND CUDA_NVCC_FLAGS "-std=c++11")
...
@@ -175,7 +175,10 @@ list(APPEND CUDA_NVCC_FLAGS "-std=c++11")
list
(
APPEND CUDA_NVCC_FLAGS
"-Xcompiler -fPIC"
)
list
(
APPEND CUDA_NVCC_FLAGS
"-Xcompiler -fPIC"
)
endif
(
NOT WIN32
)
endif
(
NOT WIN32
)
list
(
APPEND CUDA_NVCC_FLAGS
"--use_fast_math"
)
if
(
WITH_FAST_MATH
)
# Make use of fast math library. https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html
list
(
APPEND CUDA_NVCC_FLAGS
"--use_fast_math"
)
endif
()
# in cuda9, suppress cuda warning on eigen
# in cuda9, suppress cuda warning on eigen
list
(
APPEND CUDA_NVCC_FLAGS
"-w"
)
list
(
APPEND CUDA_NVCC_FLAGS
"-w"
)
# Set :expt-relaxed-constexpr to suppress Eigen warnings
# Set :expt-relaxed-constexpr to suppress Eigen warnings
...
...
cmake/external/eigen.cmake
浏览文件 @
3c3df1fd
...
@@ -3,6 +3,14 @@ INCLUDE(ExternalProject)
...
@@ -3,6 +3,14 @@ INCLUDE(ExternalProject)
SET
(
EIGEN_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/eigen3
)
SET
(
EIGEN_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/eigen3
)
SET
(
EIGEN_INCLUDE_DIR
${
EIGEN_SOURCE_DIR
}
/src/extern_eigen3
)
SET
(
EIGEN_INCLUDE_DIR
${
EIGEN_SOURCE_DIR
}
/src/extern_eigen3
)
INCLUDE_DIRECTORIES
(
${
EIGEN_INCLUDE_DIR
}
)
INCLUDE_DIRECTORIES
(
${
EIGEN_INCLUDE_DIR
}
)
if
(
NOT WITH_FAST_MATH
)
# EIGEN_FAST_MATH: https://eigen.tuxfamily.org/dox/TopicPreprocessorDirectives.html
# enables some optimizations which might affect the accuracy of the result.
# This currently enables the SSE vectorization of sin() and cos(),
# and speedups sqrt() for single precision.
# Defined to 1 by default. Define it to 0 to disable.
add_definitions
(
-DEIGEN_FAST_MATH=0
)
endif
()
if
(
WITH_AMD_GPU
)
if
(
WITH_AMD_GPU
)
ExternalProject_Add
(
ExternalProject_Add
(
...
...
cmake/flags.cmake
浏览文件 @
3c3df1fd
...
@@ -157,6 +157,8 @@ if (APPLE)
...
@@ -157,6 +157,8 @@ if (APPLE)
# On Mac OS X build fat binaries with x86_64 architectures by default.
# On Mac OS X build fat binaries with x86_64 architectures by default.
set
(
CMAKE_OSX_ARCHITECTURES
"x86_64"
CACHE STRING
"Build architectures for OSX"
FORCE
)
set
(
CMAKE_OSX_ARCHITECTURES
"x86_64"
CACHE STRING
"Build architectures for OSX"
FORCE
)
endif
()
endif
()
# On Mac OS X register class specifier is deprecated and will cause warning error on latest clang 10.0
set
(
COMMON_FLAGS -Wno-deprecated-register
)
endif
(
APPLE
)
endif
(
APPLE
)
if
(
LINUX
)
if
(
LINUX
)
...
...
paddle/fluid/framework/rw_lock.h
浏览文件 @
3c3df1fd
...
@@ -46,6 +46,7 @@ struct RWLock {
...
@@ -46,6 +46,7 @@ struct RWLock {
private:
private:
pthread_rwlock_t
lock_
;
pthread_rwlock_t
lock_
;
};
};
// TODO(paddle-dev): Support RWLock for WIN32 for correctness.
#else
#else
// https://stackoverflow.com/questions/7125250/making-pthread-rwlock-wrlock-recursive
// https://stackoverflow.com/questions/7125250/making-pthread-rwlock-wrlock-recursive
// In windows, rw_lock seems like a hack. Use empty object and do nothing.
// In windows, rw_lock seems like a hack. Use empty object and do nothing.
...
...
paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc
浏览文件 @
3c3df1fd
...
@@ -27,6 +27,9 @@ void SetConfig(AnalysisConfig *cfg) {
...
@@ -27,6 +27,9 @@ void SetConfig(AnalysisConfig *cfg) {
cfg
->
device
=
0
;
cfg
->
device
=
0
;
cfg
->
enable_ir_optim
=
true
;
cfg
->
enable_ir_optim
=
true
;
cfg
->
specify_input_name
=
true
;
cfg
->
specify_input_name
=
true
;
#ifdef PADDLE_WITH_MKLDNN
cfg
->
_use_mkldnn
=
true
;
#endif
}
}
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
...
...
paddle/fluid/operators/top_k_op.cu
浏览文件 @
3c3df1fd
...
@@ -256,36 +256,65 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
...
@@ -256,36 +256,65 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
* 3. go to the second setp, until one thread's topk value is null;
* 3. go to the second setp, until one thread's topk value is null;
* 4. go to the first setp, until get the topk value.
* 4. go to the first setp, until get the topk value.
*/
*/
template
<
typename
T
,
int
MaxLength
,
int
BlockSize
>
template
<
typename
T
,
int
MaxLength
,
int
BlockSize
>
__global__
void
KeMatrixTopK
(
T
*
output
,
int
output_stride
,
int64_t
*
indices
,
__global__
void
KeMatrixTopK
(
T
*
output
,
int
output_stride
,
int64_t
*
indices
,
const
T
*
src
,
int
lds
,
int
dim
,
int
k
)
{
const
T
*
src
,
int
lds
,
int
dim
,
int
k
,
int
grid_dim
,
int
num
)
{
__shared__
Pair
<
T
>
sh_topk
[
BlockSize
];
__shared__
Pair
<
T
>
sh_topk
[
BlockSize
];
__shared__
int
maxid
[
BlockSize
/
2
];
__shared__
int
maxid
[
BlockSize
/
2
];
const
int
tid
=
threadIdx
.
x
;
const
int
tid
=
threadIdx
.
x
;
const
int
warp
=
threadIdx
.
x
/
32
;
const
int
warp
=
threadIdx
.
x
/
32
;
output
+=
blockIdx
.
x
*
output_stride
;
indices
+=
blockIdx
.
x
*
k
;
Pair
<
T
>
topk
[
MaxLength
];
const
int
bid
=
blockIdx
.
x
;
int
beam
=
MaxLength
;
for
(
int
i
=
bid
;
i
<
num
;
i
+=
grid_dim
)
{
Pair
<
T
>
max
;
output
+=
i
*
output_stride
;
bool
is_empty
=
false
;
indices
+=
i
*
k
;
bool
firststep
=
true
;
Pair
<
T
>
topk
[
MaxLength
];
int
beam
=
MaxLength
;
Pair
<
T
>
max
;
bool
is_empty
=
false
;
bool
firststep
=
true
;
for
(
int
k
=
0
;
k
<
MaxLength
;
k
++
)
{
topk
[
k
].
set
(
-
INFINITY
,
-
1
);
}
while
(
k
)
{
ThreadGetTopK
<
T
,
MaxLength
,
BlockSize
>
(
topk
,
&
beam
,
k
,
src
+
i
*
lds
,
&
firststep
,
&
is_empty
,
&
max
,
dim
,
tid
);
for
(
int
k
=
0
;
k
<
MaxLength
;
k
++
)
{
sh_topk
[
tid
]
=
topk
[
0
];
topk
[
k
].
set
(
-
INFINITY
,
-
1
);
BlockReduce
<
T
,
MaxLength
,
BlockSize
>
(
sh_topk
,
maxid
,
topk
,
&
output
,
&
indices
,
&
beam
,
&
k
,
tid
,
warp
);
}
}
}
while
(
k
)
{
}
ThreadGetTopK
<
T
,
MaxLength
,
BlockSize
>
(
topk
,
&
beam
,
k
,
src
+
blockIdx
.
x
*
lds
,
&
firststep
,
inline
static
int
GetDesiredBlockDim
(
int
dim
)
{
&
is_empty
,
&
max
,
dim
,
tid
);
if
(
dim
>
128
)
{
return
256
;
sh_topk
[
tid
]
=
topk
[
0
];
}
else
if
(
dim
>
64
)
{
BlockReduce
<
T
,
MaxLength
,
BlockSize
>
(
sh_topk
,
maxid
,
topk
,
&
output
,
return
128
;
&
indices
,
&
beam
,
&
k
,
tid
,
warp
);
}
else
if
(
dim
>
32
)
{
return
64
;
}
else
{
return
32
;
}
}
}
}
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
case (dim): { \
constexpr auto kBlockDim = (dim); \
__VA_ARGS__; \
} break
#define FIXED_BLOCK_DIM(...) \
FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
template
<
typename
T
>
template
<
typename
T
>
class
TopkOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
class
TopkOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
@@ -310,18 +339,26 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -310,18 +339,26 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
// NOTE: pass lds and dim same to input width.
// NOTE: pass lds and dim same to input width.
// NOTE: old matrix implementation of stride is different to eigen.
// NOTE: old matrix implementation of stride is different to eigen.
// TODO(typhoonzero): refine this kernel.
// TODO(typhoonzero): refine this kernel.
dim3
threads
(
256
,
1
);
const
int
kMaxHeight
=
2048
;
dim3
grid
(
input_height
,
1
);
int
gridx
=
input_height
<
kMaxHeight
?
input_height
:
kMaxHeight
;
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
KeMatrixTopK
<
T
,
5
,
256
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
switch
(
GetDesiredBlockDim
(
input_width
))
{
ctx
.
device_context
())
FIXED_BLOCK_DIM
(
.
stream
()
>>>
(
KeMatrixTopK
<
T
,
5
,
output_data
,
output
->
dims
()[
1
],
indices_data
,
input_data
,
input_width
,
kBlockDim
><<<
gridx
,
kBlockDim
,
0
,
dev_ctx
.
stream
()
>>>
(
input_width
,
static_cast
<
int
>
(
k
));
output_data
,
output
->
dims
()[
1
],
indices_data
,
input_data
,
input_width
,
input_width
,
static_cast
<
int
>
(
k
),
gridx
,
input_height
));
default:
PADDLE_THROW
(
"Error"
);
}
}
}
};
};
#undef FIXED_BLOCK_DIM_BASE
#undef FIXED_BLOCK_DIM
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/operators/while_op.cc
浏览文件 @
3c3df1fd
...
@@ -224,10 +224,12 @@ class WhileGradOp : public framework::OperatorBase {
...
@@ -224,10 +224,12 @@ class WhileGradOp : public framework::OperatorBase {
if
(
cur_scope_iter
==
step_scopes
->
rbegin
())
{
if
(
cur_scope_iter
==
step_scopes
->
rbegin
())
{
auto
*
var
=
(
*
cur_scope_iter
)
->
FindVar
(
inside_grad_name
);
auto
*
var
=
(
*
cur_scope_iter
)
->
FindVar
(
inside_grad_name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
"Can not find var %s"
,
inside_grad_name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
"Can not find var %s"
,
inside_grad_name
);
PADDLE_ENFORCE
(
var
->
IsType
<
framework
::
LoDTensorArray
>
()
||
PADDLE_ENFORCE
(
var
->
IsType
<
LoDTensor
>
(),
var
->
IsType
<
framework
::
LoDTensorArray
>
()
||
"Currently the type of var only can be LoDTensorArray "
var
->
IsType
<
LoDTensor
>
(),
"or LoDTensor."
);
"Currently the type of var only can be LoDTensorArray, "
"or LoDTensor, but the received var[%s] is %s."
,
inside_grad_name
,
var
->
Type
().
name
());
if
(
var
->
IsType
<
LoDTensor
>
())
{
if
(
var
->
IsType
<
LoDTensor
>
())
{
auto
&
inside_tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
auto
&
inside_tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
...
...
paddle/fluid/platform/gpu_info.cc
浏览文件 @
3c3df1fd
...
@@ -20,8 +20,11 @@ limitations under the License. */
...
@@ -20,8 +20,11 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
DEFINE_double
(
fraction_of_gpu_memory_to_use
,
0.92
,
DEFINE_double
(
fraction_of_gpu_memory_to_use
,
0.92
,
"Default use 92% of GPU memory for PaddlePaddle,"
"Allocate a trunk of gpu memory that is this fraction of the "
"reserve the rest for page tables, etc"
);
"total gpu memory size. Future memory usage will be allocated "
"from the trunk. If the trunk doesn't have enough gpu memory, "
"additional trunks of the same size will be requested from gpu "
"until the gpu has no memory left for another trunk."
);
namespace
paddle
{
namespace
paddle
{
namespace
platform
{
namespace
platform
{
...
...
paddle/scripts/paddle_build.sh
浏览文件 @
3c3df1fd
...
@@ -598,7 +598,7 @@ EOF
...
@@ -598,7 +598,7 @@ EOF
EOF
EOF
if
[[
${
WITH_GPU
}
==
"ON"
]]
;
then
if
[[
${
WITH_GPU
}
==
"ON"
]]
;
then
NCCL_DEPS
=
"apt-get install -y --allow-downgrades libnccl2=2.2.13-1+cuda
${
CUDA_MAJOR
}
libnccl-dev=2.2.13-1+cuda
${
CUDA_MAJOR
}
&&
"
NCCL_DEPS
=
"apt-get install -y --allow-downgrades libnccl2=2.2.13-1+cuda
${
CUDA_MAJOR
}
libnccl-dev=2.2.13-1+cuda
${
CUDA_MAJOR
}
|| true
"
else
else
NCCL_DEPS
=
""
NCCL_DEPS
=
""
fi
fi
...
@@ -614,9 +614,8 @@ EOF
...
@@ -614,9 +614,8 @@ EOF
cat
>>
${
PADDLE_ROOT
}
/build/Dockerfile
<<
EOF
cat
>>
${
PADDLE_ROOT
}
/build/Dockerfile
<<
EOF
ADD python/dist/*.whl /
ADD python/dist/*.whl /
# run paddle version to install python packages first
# run paddle version to install python packages first
RUN apt-get update &&
\
RUN apt-get update &&
${
NCCL_DEPS
}
${
NCCL_DEPS
}
\
RUN apt-get install -y wget python-pip python-opencv libgtk2.0-dev dmidecode python-tk && easy_install -U pip &&
\
apt-get install -y wget python-pip python-opencv libgtk2.0-dev dmidecode python-tk && easy_install -U pip &&
\
pip install /*.whl; apt-get install -f -y &&
\
pip install /*.whl; apt-get install -f -y &&
\
apt-get clean -y &&
\
apt-get clean -y &&
\
rm -f /*.whl &&
\
rm -f /*.whl &&
\
...
...
python/paddle/fluid/layers/control_flow.py
浏览文件 @
3c3df1fd
...
@@ -1570,6 +1570,10 @@ class DynamicRNN(object):
...
@@ -1570,6 +1570,10 @@ class DynamicRNN(object):
The dynamic RNN can mark multiple variables as its output. Use `drnn()` to
The dynamic RNN can mark multiple variables as its output. Use `drnn()` to
get the output sequence.
get the output sequence.
NOTES:
Currently it is not supported that setting is_sparse to True of any
layers within DynamicRNN.
"""
"""
BEFORE_RNN
=
0
BEFORE_RNN
=
0
IN_RNN
=
1
IN_RNN
=
1
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录