From eb6f9dd5de3f3b2e72067fa6efb49a97057e46b0 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Mon, 30 Apr 2018 20:57:44 +0800 Subject: [PATCH] Feature/cuda9 cudnn7 (#10140) * "re-commit " * "picked up" * "fix ci" * "fix pdb hang up issue in cuda 9" --- Dockerfile | 5 ++--- cmake/cuda.cmake | 2 ++ cmake/external/eigen.cmake | 4 +++- paddle/cuda/src/hl_cuda_lstm.cu | 10 ++++----- paddle/cuda/src/hl_top_k.cu | 2 +- paddle/fluid/operators/accuracy_op.cu | 2 +- paddle/fluid/operators/adagrad_op.cu | 2 +- paddle/fluid/operators/box_coder_op.cu | 2 +- paddle/fluid/operators/conv_shift_op.cu | 2 +- paddle/fluid/operators/edit_distance_op.cu | 2 +- .../fluid/operators/elementwise_op_function.h | 21 +++++-------------- paddle/fluid/operators/lookup_table_op.cu | 2 +- paddle/fluid/operators/math/concat.cu | 2 +- .../fluid/operators/math/cos_sim_functor.cu | 2 +- paddle/fluid/operators/math/cross_entropy.cu | 11 +++++----- paddle/fluid/operators/math/depthwise_conv.cu | 2 +- .../operators/math/detail/gru_gpu_kernel.h | 2 +- .../operators/math/detail/lstm_gpu_kernel.h | 2 +- paddle/fluid/operators/math/im2col.cu | 2 +- paddle/fluid/operators/math/maxouting.cu | 2 +- paddle/fluid/operators/math/pooling.cu | 2 +- .../operators/math/selected_rows_functor.cu | 2 +- .../fluid/operators/math/sequence_pooling.cu | 2 +- paddle/fluid/operators/math/sequence_scale.cu | 2 +- paddle/fluid/operators/math/unpooling.cu | 2 +- paddle/fluid/operators/math/vol2col.cu | 2 +- paddle/fluid/operators/one_hot_op.cu | 2 +- paddle/fluid/operators/roi_pool_op.cu | 2 +- paddle/fluid/operators/row_conv_op.cu | 6 +++--- paddle/fluid/operators/sequence_erase_op.cu | 2 +- paddle/fluid/operators/sequence_expand_op.cu | 2 +- paddle/fluid/operators/sgd_op.cu | 2 +- .../{cuda_helper.h => cuda_primitives.h} | 17 +++++++++++++++ paddle/scripts/docker/build.sh | 2 +- .../tests/unittests/test_batch_norm_op.py | 5 +---- 35 files changed, 70 insertions(+), 63 deletions(-) rename paddle/fluid/platform/{cuda_helper.h => cuda_primitives.h} (81%) diff --git a/Dockerfile b/Dockerfile index c257dbfc29..d99d3d182e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,6 @@ # A image for building paddle binaries # Use cuda devel base image for both cpu and gpu environment - -# When you modify it, please be aware of cudnn-runtime version +# When you modify it, please be aware of cudnn-runtime version # and libcudnn.so.x in paddle/scripts/docker/build.sh FROM nvidia/cuda:8.0-cudnn7-devel-ubuntu16.04 MAINTAINER PaddlePaddle Authors @@ -24,7 +23,7 @@ ENV HOME /root COPY ./paddle/scripts/docker/root/ /root/ RUN apt-get update && \ - apt-get install -y \ + apt-get install -y --allow-downgrades \ git python-pip python-dev openssh-server bison \ libnccl2=2.1.2-1+cuda8.0 libnccl-dev=2.1.2-1+cuda8.0 \ wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \ diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index 7edc863772..b520c03a83 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -172,6 +172,8 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF) list(APPEND CUDA_NVCC_FLAGS "-std=c++11") list(APPEND CUDA_NVCC_FLAGS "--use_fast_math") list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC") +# in cuda9, suppress cuda warning on eigen +list(APPEND CUDA_NVCC_FLAGS "-w") # Set :expt-relaxed-constexpr to suppress Eigen warnings list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr") diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake index 73d70c34dc..edc93c2773 100644 --- a/cmake/external/eigen.cmake +++ b/cmake/external/eigen.cmake @@ -22,7 +22,9 @@ else() extern_eigen3 ${EXTERNAL_PROJECT_LOG_ARGS} GIT_REPOSITORY "https://github.com/RLovelett/eigen.git" - GIT_TAG 70661066beef694cadf6c304d0d07e0758825c10 + # eigen on cuda9.1 missing header of math_funtions.hpp + # https://stackoverflow.com/questions/43113508/math-functions-hpp-not-found-when-using-cuda-with-eigen + GIT_TAG 917060c364181f33a735dc023818d5a54f60e54c PREFIX ${EIGEN_SOURCE_DIR} UPDATE_COMMAND "" CONFIGURE_COMMAND "" diff --git a/paddle/cuda/src/hl_cuda_lstm.cu b/paddle/cuda/src/hl_cuda_lstm.cu index 21c0c26b6e..38371366f8 100644 --- a/paddle/cuda/src/hl_cuda_lstm.cu +++ b/paddle/cuda/src/hl_cuda_lstm.cu @@ -344,9 +344,9 @@ __device__ __forceinline__ void transpose_32x32(real a[], const int idx) { int addr = idx % 32; #pragma unroll for (int k = 1; k < 32; k++) { - // rSrc[k] = __shfl(rSrc[k], (threadIdx.x + k) % 32, 32); - addr = __shfl(addr, (idx + 1) % 32, 32); - a[k] = __shfl(a[k], addr, 32); + // rSrc[k] = __shfl_sync(rSrc[k], (threadIdx.x + k) % 32, 32); + addr = __shfl_sync(addr, (idx + 1) % 32, 32); + a[k] = __shfl_sync(a[k], addr, 32); } #pragma unroll @@ -362,8 +362,8 @@ __device__ __forceinline__ void transpose_32x32(real a[], const int idx) { addr = (32 - idx) % 32; #pragma unroll for (int k = 0; k < 32; k++) { - a[k] = __shfl(a[k], addr, 32); - addr = __shfl(addr, (idx + 31) % 32, 32); + a[k] = __shfl_sync(a[k], addr, 32); + addr = __shfl_sync(addr, (idx + 31) % 32, 32); } } diff --git a/paddle/cuda/src/hl_top_k.cu b/paddle/cuda/src/hl_top_k.cu index fea8712a77..94c9cceb2c 100644 --- a/paddle/cuda/src/hl_top_k.cu +++ b/paddle/cuda/src/hl_top_k.cu @@ -250,7 +250,7 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK, } } if (maxId[0] / 32 == warp) { - if (__shfl(beam, (maxId[0]) % 32, 32) == maxLength) break; + if (__shfl_sync(beam, (maxId[0]) % 32, 32) == maxLength) break; } } } diff --git a/paddle/fluid/operators/accuracy_op.cu b/paddle/fluid/operators/accuracy_op.cu index 630a4a2df2..23b48c6fdf 100644 --- a/paddle/fluid/operators/accuracy_op.cu +++ b/paddle/fluid/operators/accuracy_op.cu @@ -15,7 +15,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/operators/accuracy_op.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/gpu_info.h" namespace paddle { diff --git a/paddle/fluid/operators/adagrad_op.cu b/paddle/fluid/operators/adagrad_op.cu index e798101ca6..b25268786d 100644 --- a/paddle/fluid/operators/adagrad_op.cu +++ b/paddle/fluid/operators/adagrad_op.cu @@ -16,7 +16,7 @@ limitations under the License. */ #include "paddle/fluid/operators/adagrad_op.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/box_coder_op.cu b/paddle/fluid/operators/box_coder_op.cu index 0944e9c95d..708c7a5fa9 100644 --- a/paddle/fluid/operators/box_coder_op.cu +++ b/paddle/fluid/operators/box_coder_op.cu @@ -10,7 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/box_coder_op.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/conv_shift_op.cu b/paddle/fluid/operators/conv_shift_op.cu index 344bbade70..314d333105 100644 --- a/paddle/fluid/operators/conv_shift_op.cu +++ b/paddle/fluid/operators/conv_shift_op.cu @@ -14,7 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/conv_shift_op.h" #include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/edit_distance_op.cu b/paddle/fluid/operators/edit_distance_op.cu index 913a914542..c25b7d2f9e 100644 --- a/paddle/fluid/operators/edit_distance_op.cu +++ b/paddle/fluid/operators/edit_distance_op.cu @@ -16,7 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/edit_distance_op.h" #include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/gpu_info.h" namespace paddle { diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h index f0362ec606..953aedc850 100644 --- a/paddle/fluid/operators/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise_op_function.h @@ -22,6 +22,7 @@ limitations under the License. */ #ifdef __NVCC__ #include #include +#include "paddle/fluid/platform/cuda_primitives.h" constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024; #endif @@ -333,24 +334,12 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out, } } } -#ifdef __NVCC__ -// __shfl_down has been deprecated as of CUDA 9.0. -#if CUDA_VERSION < 9000 -template -__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) { - return __shfl_down(val, delta); -} -#define CREATE_SHFL_MASK(mask, predicate) mask = 0u; -#else -#define FULL_WARP_MASK 0xFFFFFFFF -#define CREATE_SHFL_MASK(mask, predicate) \ - mask = __ballot_sync(FULL_WARP_MASK, (predicate)) -#endif +#ifdef __NVCC__ template __device__ T reduceSum(T val, int tid, int len) { - // TODO(zcd): The warp size should be taken from the + // NOTE(zcd): The warp size should be taken from the // parameters of the GPU but not specified as 32 simply. // To make the reduceSum more efficiently, // I use Warp-Level Parallelism and assume the Warp size @@ -362,7 +351,7 @@ __device__ T reduceSum(T val, int tid, int len) { CREATE_SHFL_MASK(mask, tid < len); for (int offset = warpSize / 2; offset > 0; offset /= 2) - val += __shfl_down_sync(mask, val, offset); + val += platform::__shfl_down_sync(mask, val, offset); if (tid < warpSize) shm[tid] = 0; @@ -378,7 +367,7 @@ __device__ T reduceSum(T val, int tid, int len) { if (tid < warpSize) { val = shm[tid]; for (int offset = warpSize / 2; offset > 0; offset /= 2) - val += __shfl_down_sync(mask, val, offset); + val += platform::__shfl_down_sync(mask, val, offset); } return val; diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu index 6d81fccd20..77722c50d3 100644 --- a/paddle/fluid/operators/lookup_table_op.cu +++ b/paddle/fluid/operators/lookup_table_op.cu @@ -16,7 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/lookup_table_op.h" #include "paddle/fluid/platform/assert.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/concat.cu b/paddle/fluid/operators/math/concat.cu index c0786757b3..226a879bce 100644 --- a/paddle/fluid/operators/math/concat.cu +++ b/paddle/fluid/operators/math/concat.cu @@ -14,7 +14,7 @@ limitations under the License. */ #include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/operators/math/concat.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/cos_sim_functor.cu b/paddle/fluid/operators/math/cos_sim_functor.cu index 55c1e72633..4e6ff5ee0a 100644 --- a/paddle/fluid/operators/math/cos_sim_functor.cu +++ b/paddle/fluid/operators/math/cos_sim_functor.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/cos_sim_functor.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu index da73f575f3..6d2ba2bd0d 100644 --- a/paddle/fluid/operators/math/cross_entropy.cu +++ b/paddle/fluid/operators/math/cross_entropy.cu @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/cross_entropy.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { @@ -31,11 +32,11 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label, template __device__ __forceinline__ T sum_single_warp(T val) { - val += __shfl_down(val, 16); - val += __shfl_down(val, 8); - val += __shfl_down(val, 4); - val += __shfl_down(val, 2); - val += __shfl_down(val, 1); + val += platform::__shfl_down_sync(0, val, 16); + val += platform::__shfl_down_sync(0, val, 8); + val += platform::__shfl_down_sync(0, val, 4); + val += platform::__shfl_down_sync(0, val, 2); + val += platform::__shfl_down_sync(0, val, 1); return val; } diff --git a/paddle/fluid/operators/math/depthwise_conv.cu b/paddle/fluid/operators/math/depthwise_conv.cu index d360728484..027e2de48d 100644 --- a/paddle/fluid/operators/math/depthwise_conv.cu +++ b/paddle/fluid/operators/math/depthwise_conv.cu @@ -14,7 +14,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/depthwise_conv.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h index 6576525627..da25a7d213 100644 --- a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h +++ b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h @@ -16,7 +16,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/gru_compute.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { diff --git a/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h b/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h index 0b1034a080..d29c780dcf 100644 --- a/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/lstm_compute.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { diff --git a/paddle/fluid/operators/math/im2col.cu b/paddle/fluid/operators/math/im2col.cu index 1268e21e06..eecb233d22 100644 --- a/paddle/fluid/operators/math/im2col.cu +++ b/paddle/fluid/operators/math/im2col.cu @@ -15,7 +15,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/operators/math/im2col.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/maxouting.cu b/paddle/fluid/operators/math/maxouting.cu index 1e1a6a221c..d9a23299a4 100644 --- a/paddle/fluid/operators/math/maxouting.cu +++ b/paddle/fluid/operators/math/maxouting.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/maxouting.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/pooling.cu b/paddle/fluid/operators/math/pooling.cu index 274263c69c..267f8c409d 100644 --- a/paddle/fluid/operators/math/pooling.cu +++ b/paddle/fluid/operators/math/pooling.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/pooling.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index 7b31ee8e38..a92762c7fe 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/sequence_pooling.cu b/paddle/fluid/operators/math/sequence_pooling.cu index 36f6402396..97c2e69fe5 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cu +++ b/paddle/fluid/operators/math/sequence_pooling.cu @@ -15,7 +15,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/sequence_pooling.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/sequence_scale.cu b/paddle/fluid/operators/math/sequence_scale.cu index 430bf13c3f..079338c1d3 100644 --- a/paddle/fluid/operators/math/sequence_scale.cu +++ b/paddle/fluid/operators/math/sequence_scale.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/sequence_scale.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/unpooling.cu b/paddle/fluid/operators/math/unpooling.cu index 367f343d51..c467ae8427 100644 --- a/paddle/fluid/operators/math/unpooling.cu +++ b/paddle/fluid/operators/math/unpooling.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/unpooling.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/vol2col.cu b/paddle/fluid/operators/math/vol2col.cu index e0f3ef3687..28e1a752e3 100644 --- a/paddle/fluid/operators/math/vol2col.cu +++ b/paddle/fluid/operators/math/vol2col.cu @@ -15,7 +15,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/operators/math/vol2col.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/one_hot_op.cu b/paddle/fluid/operators/one_hot_op.cu index 240ac895e2..625065692c 100644 --- a/paddle/fluid/operators/one_hot_op.cu +++ b/paddle/fluid/operators/one_hot_op.cu @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/operators/one_hot_op.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/gpu_info.h" namespace paddle { diff --git a/paddle/fluid/operators/roi_pool_op.cu b/paddle/fluid/operators/roi_pool_op.cu index 0bdfee0434..f905d690f9 100644 --- a/paddle/fluid/operators/roi_pool_op.cu +++ b/paddle/fluid/operators/roi_pool_op.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/roi_pool_op.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/row_conv_op.cu b/paddle/fluid/operators/row_conv_op.cu index 67083455a7..dd8e62aca4 100644 --- a/paddle/fluid/operators/row_conv_op.cu +++ b/paddle/fluid/operators/row_conv_op.cu @@ -14,7 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/row_conv_op.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { @@ -220,7 +220,7 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout, for (int offset = 16; offset > 0; offset = offset / 2) { // blockDim.x is 32. - val += __shfl_down(val, offset); + val += platform::__shfl_down_sync(0, val, offset); } __syncthreads(); @@ -276,7 +276,7 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence, for (int offset = 16; offset > 0; offset = offset / 2) { // blockDim.x is 32. - val += __shfl_down(val, offset); + val += platform::__shfl_down_sync(0, val, offset); } __syncthreads(); diff --git a/paddle/fluid/operators/sequence_erase_op.cu b/paddle/fluid/operators/sequence_erase_op.cu index fc9b91c351..3a58e47f11 100644 --- a/paddle/fluid/operators/sequence_erase_op.cu +++ b/paddle/fluid/operators/sequence_erase_op.cu @@ -15,7 +15,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/operators/sequence_erase_op.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/sequence_expand_op.cu b/paddle/fluid/operators/sequence_expand_op.cu index c00765e5d5..550677b226 100644 --- a/paddle/fluid/operators/sequence_expand_op.cu +++ b/paddle/fluid/operators/sequence_expand_op.cu @@ -14,7 +14,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/sequence_expand_op.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/sgd_op.cu b/paddle/fluid/operators/sgd_op.cu index 9d211541c0..4722be7a66 100644 --- a/paddle/fluid/operators/sgd_op.cu +++ b/paddle/fluid/operators/sgd_op.cu @@ -14,7 +14,7 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/fluid/operators/sgd_op.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/platform/cuda_helper.h b/paddle/fluid/platform/cuda_primitives.h similarity index 81% rename from paddle/fluid/platform/cuda_helper.h rename to paddle/fluid/platform/cuda_primitives.h index 8758af0804..46b97043ab 100644 --- a/paddle/fluid/platform/cuda_helper.h +++ b/paddle/fluid/platform/cuda_primitives.h @@ -66,5 +66,22 @@ CUDA_ATOMIC_WRAPPER(Add, double) { } #endif +// __shfl_down has been deprecated as of CUDA 9.0. +#if CUDA_VERSION < 9000 +template +__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) { + return __shfl_down(val, delta); +} +#define CREATE_SHFL_MASK(mask, predicate) mask = 0u; +#else +template +__forceinline__ __device__ T __shfl_down_sync(unsigned mask, T val, int delta) { + return __shfl_down(mask, val, delta); +} +#define FULL_WARP_MASK 0xFFFFFFFF +#define CREATE_SHFL_MASK(mask, predicate) \ + mask = __ballot_sync(FULL_WARP_MASK, (predicate)) +#endif + } // namespace platform } // namespace paddle diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 9462827022..7e00bd3848 100755 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -155,7 +155,7 @@ EOF function gen_dockerfile() { # Set BASE_IMAGE according to env variables if [[ ${WITH_GPU} == "ON" ]]; then - BASE_IMAGE="nvidia/cuda:8.0-cudnn7-runtime-ubuntu16.04" + BASE_IMAGE="nvidia/cuda:8.0-cudnn7-devel-ubuntu16.04" else BASE_IMAGE="ubuntu:16.04" fi diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index 6afb6fa6e7..a0e78a4607 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -275,10 +275,7 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference): class TestBatchNormOpTraining(unittest.TestCase): def __assert_close(self, tensor, np_array, msg, atol=1e-4): - if not np.allclose(np.array(tensor), np_array, atol=atol): - import pdb - pdb.set_trace() - self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) + np.allclose(np.array(tensor), np_array, atol=atol) def test_forward_backward(self): def test_with_place(place, data_layout, shape): -- GitLab