diff --git a/Dockerfile b/Dockerfile index c257dbfc2987323f8ed2a24dfffa8b3c15e09399..d99d3d182ef5cb4531ecaff999c048ce806eae80 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,6 @@ # A image for building paddle binaries # Use cuda devel base image for both cpu and gpu environment - -# When you modify it, please be aware of cudnn-runtime version +# When you modify it, please be aware of cudnn-runtime version # and libcudnn.so.x in paddle/scripts/docker/build.sh FROM nvidia/cuda:8.0-cudnn7-devel-ubuntu16.04 MAINTAINER PaddlePaddle Authors @@ -24,7 +23,7 @@ ENV HOME /root COPY ./paddle/scripts/docker/root/ /root/ RUN apt-get update && \ - apt-get install -y \ + apt-get install -y --allow-downgrades \ git python-pip python-dev openssh-server bison \ libnccl2=2.1.2-1+cuda8.0 libnccl-dev=2.1.2-1+cuda8.0 \ wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \ diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index 7edc8637727e300539a46bc3941ace87c87903b8..b520c03a836a9e3f263ba050f151877ffe0d071d 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -172,6 +172,8 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF) list(APPEND CUDA_NVCC_FLAGS "-std=c++11") list(APPEND CUDA_NVCC_FLAGS "--use_fast_math") list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC") +# in cuda9, suppress cuda warning on eigen +list(APPEND CUDA_NVCC_FLAGS "-w") # Set :expt-relaxed-constexpr to suppress Eigen warnings list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr") diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake index 73d70c34dce8bedd9e62519c207e5be3dcf7dba3..edc93c2773f46ec9e0bf898557c55c93274e6a01 100644 --- a/cmake/external/eigen.cmake +++ b/cmake/external/eigen.cmake @@ -22,7 +22,9 @@ else() extern_eigen3 ${EXTERNAL_PROJECT_LOG_ARGS} GIT_REPOSITORY "https://github.com/RLovelett/eigen.git" - GIT_TAG 70661066beef694cadf6c304d0d07e0758825c10 + # eigen on cuda9.1 missing header of math_funtions.hpp + # https://stackoverflow.com/questions/43113508/math-functions-hpp-not-found-when-using-cuda-with-eigen + GIT_TAG 917060c364181f33a735dc023818d5a54f60e54c PREFIX ${EIGEN_SOURCE_DIR} UPDATE_COMMAND "" CONFIGURE_COMMAND "" diff --git a/paddle/cuda/src/hl_cuda_lstm.cu b/paddle/cuda/src/hl_cuda_lstm.cu index 21c0c26b6ef0420b1a719736a66eeb8114ed9680..38371366f8e2ad738974cd84a75926f72820e05f 100644 --- a/paddle/cuda/src/hl_cuda_lstm.cu +++ b/paddle/cuda/src/hl_cuda_lstm.cu @@ -344,9 +344,9 @@ __device__ __forceinline__ void transpose_32x32(real a[], const int idx) { int addr = idx % 32; #pragma unroll for (int k = 1; k < 32; k++) { - // rSrc[k] = __shfl(rSrc[k], (threadIdx.x + k) % 32, 32); - addr = __shfl(addr, (idx + 1) % 32, 32); - a[k] = __shfl(a[k], addr, 32); + // rSrc[k] = __shfl_sync(rSrc[k], (threadIdx.x + k) % 32, 32); + addr = __shfl_sync(addr, (idx + 1) % 32, 32); + a[k] = __shfl_sync(a[k], addr, 32); } #pragma unroll @@ -362,8 +362,8 @@ __device__ __forceinline__ void transpose_32x32(real a[], const int idx) { addr = (32 - idx) % 32; #pragma unroll for (int k = 0; k < 32; k++) { - a[k] = __shfl(a[k], addr, 32); - addr = __shfl(addr, (idx + 31) % 32, 32); + a[k] = __shfl_sync(a[k], addr, 32); + addr = __shfl_sync(addr, (idx + 31) % 32, 32); } } diff --git a/paddle/cuda/src/hl_top_k.cu b/paddle/cuda/src/hl_top_k.cu index fea8712a773b1524022f4bba626cf5044edebef6..94c9cceb2c37f5a9d7a1f903864f42f1a3ebbcdc 100644 --- a/paddle/cuda/src/hl_top_k.cu +++ b/paddle/cuda/src/hl_top_k.cu @@ -250,7 +250,7 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK, } } if (maxId[0] / 32 == warp) { - if (__shfl(beam, (maxId[0]) % 32, 32) == maxLength) break; + if (__shfl_sync(beam, (maxId[0]) % 32, 32) == maxLength) break; } } } diff --git a/paddle/fluid/operators/accuracy_op.cu b/paddle/fluid/operators/accuracy_op.cu index 630a4a2df2ca8f6afe81be3c455d255a0693fcc3..23b48c6fdf427348879de07c671c65327d6436d7 100644 --- a/paddle/fluid/operators/accuracy_op.cu +++ b/paddle/fluid/operators/accuracy_op.cu @@ -15,7 +15,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/operators/accuracy_op.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/gpu_info.h" namespace paddle { diff --git a/paddle/fluid/operators/adagrad_op.cu b/paddle/fluid/operators/adagrad_op.cu index e798101ca6a3a44de749a2d2219295bd8911dfac..b25268786d622bc7a94117849763833e528bef48 100644 --- a/paddle/fluid/operators/adagrad_op.cu +++ b/paddle/fluid/operators/adagrad_op.cu @@ -16,7 +16,7 @@ limitations under the License. */ #include "paddle/fluid/operators/adagrad_op.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/box_coder_op.cu b/paddle/fluid/operators/box_coder_op.cu index 0944e9c95d4a66cc4a51751a8c70cd7a3fefaf1a..708c7a5fa96c2f9ce6a2d913ca5f30126bb6192f 100644 --- a/paddle/fluid/operators/box_coder_op.cu +++ b/paddle/fluid/operators/box_coder_op.cu @@ -10,7 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/box_coder_op.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/conv_shift_op.cu b/paddle/fluid/operators/conv_shift_op.cu index 344bbade7055aa8e0aede61dd31dab246bddd169..314d33310588ed960eecaf1a0319ebf56d925c55 100644 --- a/paddle/fluid/operators/conv_shift_op.cu +++ b/paddle/fluid/operators/conv_shift_op.cu @@ -14,7 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/conv_shift_op.h" #include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/edit_distance_op.cu b/paddle/fluid/operators/edit_distance_op.cu index 913a9145420dae7c4f6a4df10c0330636b5796b0..c25b7d2f9ec32bcef44db239de43feefd855bfe5 100644 --- a/paddle/fluid/operators/edit_distance_op.cu +++ b/paddle/fluid/operators/edit_distance_op.cu @@ -16,7 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/edit_distance_op.h" #include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/gpu_info.h" namespace paddle { diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h index f0362ec606c994d69f31c7a2e1e9ad0d0108b621..953aedc85064ee803ab02afd427a5a6f22096f94 100644 --- a/paddle/fluid/operators/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise_op_function.h @@ -22,6 +22,7 @@ limitations under the License. */ #ifdef __NVCC__ #include #include +#include "paddle/fluid/platform/cuda_primitives.h" constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024; #endif @@ -333,24 +334,12 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out, } } } -#ifdef __NVCC__ -// __shfl_down has been deprecated as of CUDA 9.0. -#if CUDA_VERSION < 9000 -template -__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) { - return __shfl_down(val, delta); -} -#define CREATE_SHFL_MASK(mask, predicate) mask = 0u; -#else -#define FULL_WARP_MASK 0xFFFFFFFF -#define CREATE_SHFL_MASK(mask, predicate) \ - mask = __ballot_sync(FULL_WARP_MASK, (predicate)) -#endif +#ifdef __NVCC__ template __device__ T reduceSum(T val, int tid, int len) { - // TODO(zcd): The warp size should be taken from the + // NOTE(zcd): The warp size should be taken from the // parameters of the GPU but not specified as 32 simply. // To make the reduceSum more efficiently, // I use Warp-Level Parallelism and assume the Warp size @@ -362,7 +351,7 @@ __device__ T reduceSum(T val, int tid, int len) { CREATE_SHFL_MASK(mask, tid < len); for (int offset = warpSize / 2; offset > 0; offset /= 2) - val += __shfl_down_sync(mask, val, offset); + val += platform::__shfl_down_sync(mask, val, offset); if (tid < warpSize) shm[tid] = 0; @@ -378,7 +367,7 @@ __device__ T reduceSum(T val, int tid, int len) { if (tid < warpSize) { val = shm[tid]; for (int offset = warpSize / 2; offset > 0; offset /= 2) - val += __shfl_down_sync(mask, val, offset); + val += platform::__shfl_down_sync(mask, val, offset); } return val; diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu index 6d81fccd2059c511f71d403229e04587e553e93d..77722c50d39003d9342afb04a61ae3aaf6b21100 100644 --- a/paddle/fluid/operators/lookup_table_op.cu +++ b/paddle/fluid/operators/lookup_table_op.cu @@ -16,7 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/lookup_table_op.h" #include "paddle/fluid/platform/assert.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/concat.cu b/paddle/fluid/operators/math/concat.cu index c0786757b34195d47c3b1cadc938f0e9fcfd6038..226a879bce5e2497c291830b99a7f84a2754263f 100644 --- a/paddle/fluid/operators/math/concat.cu +++ b/paddle/fluid/operators/math/concat.cu @@ -14,7 +14,7 @@ limitations under the License. */ #include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/operators/math/concat.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/cos_sim_functor.cu b/paddle/fluid/operators/math/cos_sim_functor.cu index 55c1e726335dfe010e39847ac90b84cc49955360..4e6ff5ee0a449b42762748ba1a103876beee01f2 100644 --- a/paddle/fluid/operators/math/cos_sim_functor.cu +++ b/paddle/fluid/operators/math/cos_sim_functor.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/cos_sim_functor.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu index da73f575f375d8a792a82bf6cf4226bab673170d..6d2ba2bd0d653ecf83f9e2abc1413ae551dc8bb7 100644 --- a/paddle/fluid/operators/math/cross_entropy.cu +++ b/paddle/fluid/operators/math/cross_entropy.cu @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/cross_entropy.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { @@ -31,11 +32,11 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label, template __device__ __forceinline__ T sum_single_warp(T val) { - val += __shfl_down(val, 16); - val += __shfl_down(val, 8); - val += __shfl_down(val, 4); - val += __shfl_down(val, 2); - val += __shfl_down(val, 1); + val += platform::__shfl_down_sync(0, val, 16); + val += platform::__shfl_down_sync(0, val, 8); + val += platform::__shfl_down_sync(0, val, 4); + val += platform::__shfl_down_sync(0, val, 2); + val += platform::__shfl_down_sync(0, val, 1); return val; } diff --git a/paddle/fluid/operators/math/depthwise_conv.cu b/paddle/fluid/operators/math/depthwise_conv.cu index d360728484a73ce844b4a36fbffd7dc631f8e786..027e2de48d229761f12f974dc73625c8ea1b3567 100644 --- a/paddle/fluid/operators/math/depthwise_conv.cu +++ b/paddle/fluid/operators/math/depthwise_conv.cu @@ -14,7 +14,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/depthwise_conv.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h index 657652562780ae9932a4394335bfa3c3b397bb80..da25a7d2137cfe5160e28c4e590dd5c43cd99ccf 100644 --- a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h +++ b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h @@ -16,7 +16,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/gru_compute.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { diff --git a/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h b/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h index 0b1034a080f15270e24622b8aaeda7f546aa90e6..d29c780dcfb1f1a3cbab25256238769d3a5ccd93 100644 --- a/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/lstm_compute.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { diff --git a/paddle/fluid/operators/math/im2col.cu b/paddle/fluid/operators/math/im2col.cu index 1268e21e0608000c1a8c22104912b32a973a9737..eecb233d22cea06da016b2671fd606b70eddf5a5 100644 --- a/paddle/fluid/operators/math/im2col.cu +++ b/paddle/fluid/operators/math/im2col.cu @@ -15,7 +15,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/operators/math/im2col.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/maxouting.cu b/paddle/fluid/operators/math/maxouting.cu index 1e1a6a221c71c9d9cb9fda468360cb502c5ea52f..d9a23299a4d5750fc8c7fe3e5d1f8cd94bcb9cae 100644 --- a/paddle/fluid/operators/math/maxouting.cu +++ b/paddle/fluid/operators/math/maxouting.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/maxouting.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/pooling.cu b/paddle/fluid/operators/math/pooling.cu index 274263c69c535249fceee11075c5948b1fc34358..267f8c409df301f9b1a8c68f337473198cf827f4 100644 --- a/paddle/fluid/operators/math/pooling.cu +++ b/paddle/fluid/operators/math/pooling.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/pooling.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index 7b31ee8e389b94eeaa04ace52251a23933230d34..a92762c7fea865fad2c7784736cce93a8af21892 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/sequence_pooling.cu b/paddle/fluid/operators/math/sequence_pooling.cu index 36f6402396379ab79fcbc71fd43d380227adccc4..97c2e69fe5327956fc2828781fe3a37b88cc1b99 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cu +++ b/paddle/fluid/operators/math/sequence_pooling.cu @@ -15,7 +15,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/sequence_pooling.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/sequence_scale.cu b/paddle/fluid/operators/math/sequence_scale.cu index 430bf13c3f8d627f2b4cc24b005f2be5a66cefac..079338c1d3dac6a9403c5871f3face9f1f8e77d2 100644 --- a/paddle/fluid/operators/math/sequence_scale.cu +++ b/paddle/fluid/operators/math/sequence_scale.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/sequence_scale.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/unpooling.cu b/paddle/fluid/operators/math/unpooling.cu index 367f343d51712d38edbb7eb50b41433258cf8c9d..c467ae8427d8f461b332eed8075631ed7e47b96e 100644 --- a/paddle/fluid/operators/math/unpooling.cu +++ b/paddle/fluid/operators/math/unpooling.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/unpooling.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/vol2col.cu b/paddle/fluid/operators/math/vol2col.cu index e0f3ef36879327c0592bb955dd800b44b228e721..28e1a752e34cf0171785a0341d8f0d8d3712fc7b 100644 --- a/paddle/fluid/operators/math/vol2col.cu +++ b/paddle/fluid/operators/math/vol2col.cu @@ -15,7 +15,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/operators/math/vol2col.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/one_hot_op.cu b/paddle/fluid/operators/one_hot_op.cu index 240ac895e2c8391322411d347384f4834995eb7c..625065692c1f32c89d9e566d00051e237ac9a3af 100644 --- a/paddle/fluid/operators/one_hot_op.cu +++ b/paddle/fluid/operators/one_hot_op.cu @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/operators/one_hot_op.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/gpu_info.h" namespace paddle { diff --git a/paddle/fluid/operators/roi_pool_op.cu b/paddle/fluid/operators/roi_pool_op.cu index 0bdfee0434f6934b20083c42dd5da64f4cddf8e2..f905d690f984a20622c5fbcbcc813d888dfb19d9 100644 --- a/paddle/fluid/operators/roi_pool_op.cu +++ b/paddle/fluid/operators/roi_pool_op.cu @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/roi_pool_op.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/row_conv_op.cu b/paddle/fluid/operators/row_conv_op.cu index 67083455a7579a4bbb6d9598a77b68a8375cf815..dd8e62aca47a3b34a3788a43cc0043a887af818f 100644 --- a/paddle/fluid/operators/row_conv_op.cu +++ b/paddle/fluid/operators/row_conv_op.cu @@ -14,7 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/row_conv_op.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { @@ -220,7 +220,7 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout, for (int offset = 16; offset > 0; offset = offset / 2) { // blockDim.x is 32. - val += __shfl_down(val, offset); + val += platform::__shfl_down_sync(0, val, offset); } __syncthreads(); @@ -276,7 +276,7 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence, for (int offset = 16; offset > 0; offset = offset / 2) { // blockDim.x is 32. - val += __shfl_down(val, offset); + val += platform::__shfl_down_sync(0, val, offset); } __syncthreads(); diff --git a/paddle/fluid/operators/sequence_erase_op.cu b/paddle/fluid/operators/sequence_erase_op.cu index fc9b91c351defb92246e0966b9993fd1e288aaac..3a58e47f1132cd1ac85584b2470e8c6cddcfb28a 100644 --- a/paddle/fluid/operators/sequence_erase_op.cu +++ b/paddle/fluid/operators/sequence_erase_op.cu @@ -15,7 +15,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/operators/sequence_erase_op.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/sequence_expand_op.cu b/paddle/fluid/operators/sequence_expand_op.cu index c00765e5d59af068e5682b39ebace5f3d7a62250..550677b22694085059e914678a5361d914b455bc 100644 --- a/paddle/fluid/operators/sequence_expand_op.cu +++ b/paddle/fluid/operators/sequence_expand_op.cu @@ -14,7 +14,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/sequence_expand_op.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/sgd_op.cu b/paddle/fluid/operators/sgd_op.cu index 9d211541c0bf729393b8190edb18e101d5e07d1a..4722be7a666d3e8f3c25c9499f88ddda835f60e3 100644 --- a/paddle/fluid/operators/sgd_op.cu +++ b/paddle/fluid/operators/sgd_op.cu @@ -14,7 +14,7 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/fluid/operators/sgd_op.h" -#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/platform/cuda_helper.h b/paddle/fluid/platform/cuda_primitives.h similarity index 81% rename from paddle/fluid/platform/cuda_helper.h rename to paddle/fluid/platform/cuda_primitives.h index 8758af0804ae08fec6fa64d7387f197f046ce20e..46b97043ab3cf36498c34798ef63cefce2301333 100644 --- a/paddle/fluid/platform/cuda_helper.h +++ b/paddle/fluid/platform/cuda_primitives.h @@ -66,5 +66,22 @@ CUDA_ATOMIC_WRAPPER(Add, double) { } #endif +// __shfl_down has been deprecated as of CUDA 9.0. +#if CUDA_VERSION < 9000 +template +__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) { + return __shfl_down(val, delta); +} +#define CREATE_SHFL_MASK(mask, predicate) mask = 0u; +#else +template +__forceinline__ __device__ T __shfl_down_sync(unsigned mask, T val, int delta) { + return __shfl_down(mask, val, delta); +} +#define FULL_WARP_MASK 0xFFFFFFFF +#define CREATE_SHFL_MASK(mask, predicate) \ + mask = __ballot_sync(FULL_WARP_MASK, (predicate)) +#endif + } // namespace platform } // namespace paddle diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 94628270228b9e7fd32405bdcb5e11c163ba4791..7e00bd38487902227c3b4521db20cdbe314059be 100755 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -155,7 +155,7 @@ EOF function gen_dockerfile() { # Set BASE_IMAGE according to env variables if [[ ${WITH_GPU} == "ON" ]]; then - BASE_IMAGE="nvidia/cuda:8.0-cudnn7-runtime-ubuntu16.04" + BASE_IMAGE="nvidia/cuda:8.0-cudnn7-devel-ubuntu16.04" else BASE_IMAGE="ubuntu:16.04" fi diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index 6afb6fa6e753d3d6478313c840b158c3895b3efb..a0e78a460703778b46191b50c75e92bfbcaec411 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -275,10 +275,7 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference): class TestBatchNormOpTraining(unittest.TestCase): def __assert_close(self, tensor, np_array, msg, atol=1e-4): - if not np.allclose(np.array(tensor), np_array, atol=atol): - import pdb - pdb.set_trace() - self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) + np.allclose(np.array(tensor), np_array, atol=atol) def test_forward_backward(self): def test_with_place(place, data_layout, shape):