未验证 提交 eb6f9dd5 编写于 作者: D dzhwinter 提交者: GitHub

Feature/cuda9 cudnn7 (#10140)

* "re-commit "

* "picked up"

* "fix ci"

* "fix pdb hang up issue in cuda 9"
上级 c2620402
# A image for building paddle binaries # A image for building paddle binaries
# Use cuda devel base image for both cpu and gpu environment # Use cuda devel base image for both cpu and gpu environment
# When you modify it, please be aware of cudnn-runtime version
# When you modify it, please be aware of cudnn-runtime version
# and libcudnn.so.x in paddle/scripts/docker/build.sh # and libcudnn.so.x in paddle/scripts/docker/build.sh
FROM nvidia/cuda:8.0-cudnn7-devel-ubuntu16.04 FROM nvidia/cuda:8.0-cudnn7-devel-ubuntu16.04
MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com> MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com>
...@@ -24,7 +23,7 @@ ENV HOME /root ...@@ -24,7 +23,7 @@ ENV HOME /root
COPY ./paddle/scripts/docker/root/ /root/ COPY ./paddle/scripts/docker/root/ /root/
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y \ apt-get install -y --allow-downgrades \
git python-pip python-dev openssh-server bison \ git python-pip python-dev openssh-server bison \
libnccl2=2.1.2-1+cuda8.0 libnccl-dev=2.1.2-1+cuda8.0 \ libnccl2=2.1.2-1+cuda8.0 libnccl-dev=2.1.2-1+cuda8.0 \
wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \ wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \
......
...@@ -172,6 +172,8 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF) ...@@ -172,6 +172,8 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF)
list(APPEND CUDA_NVCC_FLAGS "-std=c++11") list(APPEND CUDA_NVCC_FLAGS "-std=c++11")
list(APPEND CUDA_NVCC_FLAGS "--use_fast_math") list(APPEND CUDA_NVCC_FLAGS "--use_fast_math")
list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC") list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC")
# in cuda9, suppress cuda warning on eigen
list(APPEND CUDA_NVCC_FLAGS "-w")
# Set :expt-relaxed-constexpr to suppress Eigen warnings # Set :expt-relaxed-constexpr to suppress Eigen warnings
list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr") list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr")
......
...@@ -22,7 +22,9 @@ else() ...@@ -22,7 +22,9 @@ else()
extern_eigen3 extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/RLovelett/eigen.git" GIT_REPOSITORY "https://github.com/RLovelett/eigen.git"
GIT_TAG 70661066beef694cadf6c304d0d07e0758825c10 # eigen on cuda9.1 missing header of math_funtions.hpp
# https://stackoverflow.com/questions/43113508/math-functions-hpp-not-found-when-using-cuda-with-eigen
GIT_TAG 917060c364181f33a735dc023818d5a54f60e54c
PREFIX ${EIGEN_SOURCE_DIR} PREFIX ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
......
...@@ -344,9 +344,9 @@ __device__ __forceinline__ void transpose_32x32(real a[], const int idx) { ...@@ -344,9 +344,9 @@ __device__ __forceinline__ void transpose_32x32(real a[], const int idx) {
int addr = idx % 32; int addr = idx % 32;
#pragma unroll #pragma unroll
for (int k = 1; k < 32; k++) { for (int k = 1; k < 32; k++) {
// rSrc[k] = __shfl(rSrc[k], (threadIdx.x + k) % 32, 32); // rSrc[k] = __shfl_sync(rSrc[k], (threadIdx.x + k) % 32, 32);
addr = __shfl(addr, (idx + 1) % 32, 32); addr = __shfl_sync(addr, (idx + 1) % 32, 32);
a[k] = __shfl(a[k], addr, 32); a[k] = __shfl_sync(a[k], addr, 32);
} }
#pragma unroll #pragma unroll
...@@ -362,8 +362,8 @@ __device__ __forceinline__ void transpose_32x32(real a[], const int idx) { ...@@ -362,8 +362,8 @@ __device__ __forceinline__ void transpose_32x32(real a[], const int idx) {
addr = (32 - idx) % 32; addr = (32 - idx) % 32;
#pragma unroll #pragma unroll
for (int k = 0; k < 32; k++) { for (int k = 0; k < 32; k++) {
a[k] = __shfl(a[k], addr, 32); a[k] = __shfl_sync(a[k], addr, 32);
addr = __shfl(addr, (idx + 31) % 32, 32); addr = __shfl_sync(addr, (idx + 31) % 32, 32);
} }
} }
......
...@@ -250,7 +250,7 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK, ...@@ -250,7 +250,7 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK,
} }
} }
if (maxId[0] / 32 == warp) { if (maxId[0] / 32 == warp) {
if (__shfl(beam, (maxId[0]) % 32, 32) == maxLength) break; if (__shfl_sync(beam, (maxId[0]) % 32, 32) == maxLength) break;
} }
} }
} }
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include <thrust/execution_policy.h> #include <thrust/execution_policy.h>
#include <thrust/reduce.h> #include <thrust/reduce.h>
#include "paddle/fluid/operators/accuracy_op.h" #include "paddle/fluid/operators/accuracy_op.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/gpu_info.h"
namespace paddle { namespace paddle {
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#include "paddle/fluid/operators/adagrad_op.h" #include "paddle/fluid/operators/adagrad_op.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -10,7 +10,7 @@ See the License for the specific language governing permissions and ...@@ -10,7 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/box_coder_op.h" #include "paddle/fluid/operators/box_coder_op.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/conv_shift_op.h" #include "paddle/fluid/operators/conv_shift_op.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/edit_distance_op.h" #include "paddle/fluid/operators/edit_distance_op.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/gpu_info.h"
namespace paddle { namespace paddle {
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#ifdef __NVCC__ #ifdef __NVCC__
#include <cuda.h> #include <cuda.h>
#include <thrust/iterator/iterator_adaptor.h> #include <thrust/iterator/iterator_adaptor.h>
#include "paddle/fluid/platform/cuda_primitives.h"
constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024; constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
#endif #endif
...@@ -333,24 +334,12 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out, ...@@ -333,24 +334,12 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
} }
} }
} }
#ifdef __NVCC__
// __shfl_down has been deprecated as of CUDA 9.0. #ifdef __NVCC__
#if CUDA_VERSION < 9000
template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
return __shfl_down(val, delta);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
template <typename T> template <typename T>
__device__ T reduceSum(T val, int tid, int len) { __device__ T reduceSum(T val, int tid, int len) {
// TODO(zcd): The warp size should be taken from the // NOTE(zcd): The warp size should be taken from the
// parameters of the GPU but not specified as 32 simply. // parameters of the GPU but not specified as 32 simply.
// To make the reduceSum more efficiently, // To make the reduceSum more efficiently,
// I use Warp-Level Parallelism and assume the Warp size // I use Warp-Level Parallelism and assume the Warp size
...@@ -362,7 +351,7 @@ __device__ T reduceSum(T val, int tid, int len) { ...@@ -362,7 +351,7 @@ __device__ T reduceSum(T val, int tid, int len) {
CREATE_SHFL_MASK(mask, tid < len); CREATE_SHFL_MASK(mask, tid < len);
for (int offset = warpSize / 2; offset > 0; offset /= 2) for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += __shfl_down_sync(mask, val, offset); val += platform::__shfl_down_sync(mask, val, offset);
if (tid < warpSize) shm[tid] = 0; if (tid < warpSize) shm[tid] = 0;
...@@ -378,7 +367,7 @@ __device__ T reduceSum(T val, int tid, int len) { ...@@ -378,7 +367,7 @@ __device__ T reduceSum(T val, int tid, int len) {
if (tid < warpSize) { if (tid < warpSize) {
val = shm[tid]; val = shm[tid];
for (int offset = warpSize / 2; offset > 0; offset /= 2) for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += __shfl_down_sync(mask, val, offset); val += platform::__shfl_down_sync(mask, val, offset);
} }
return val; return val;
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/lookup_table_op.h" #include "paddle/fluid/operators/lookup_table_op.h"
#include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat.h" #include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/cos_sim_functor.h" #include "paddle/fluid/operators/math/cos_sim_functor.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -31,11 +32,11 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label, ...@@ -31,11 +32,11 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
template <typename T> template <typename T>
__device__ __forceinline__ T sum_single_warp(T val) { __device__ __forceinline__ T sum_single_warp(T val) {
val += __shfl_down(val, 16); val += platform::__shfl_down_sync(0, val, 16);
val += __shfl_down(val, 8); val += platform::__shfl_down_sync(0, val, 8);
val += __shfl_down(val, 4); val += platform::__shfl_down_sync(0, val, 4);
val += __shfl_down(val, 2); val += platform::__shfl_down_sync(0, val, 2);
val += __shfl_down(val, 1); val += platform::__shfl_down_sync(0, val, 1);
return val; return val;
} }
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/operators/math/depthwise_conv.h" #include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#include <type_traits> #include <type_traits>
#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/gru_compute.h" #include "paddle/fluid/operators/math/gru_compute.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
......
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "paddle/fluid/operators/math/im2col.h" #include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/maxouting.h" #include "paddle/fluid/operators/math/maxouting.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/pooling.h" #include "paddle/fluid/operators/math/pooling.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence_pooling.h" #include "paddle/fluid/operators/math/sequence_pooling.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/sequence_scale.h" #include "paddle/fluid/operators/math/sequence_scale.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/unpooling.h" #include "paddle/fluid/operators/math/unpooling.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "paddle/fluid/operators/math/vol2col.h" #include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/one_hot_op.h" #include "paddle/fluid/operators/one_hot_op.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/gpu_info.h"
namespace paddle { namespace paddle {
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/roi_pool_op.h" #include "paddle/fluid/operators/roi_pool_op.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/row_conv_op.h" #include "paddle/fluid/operators/row_conv_op.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -220,7 +220,7 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout, ...@@ -220,7 +220,7 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout,
for (int offset = 16; offset > 0; for (int offset = 16; offset > 0;
offset = offset / 2) { // blockDim.x is 32. offset = offset / 2) { // blockDim.x is 32.
val += __shfl_down(val, offset); val += platform::__shfl_down_sync(0, val, offset);
} }
__syncthreads(); __syncthreads();
...@@ -276,7 +276,7 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence, ...@@ -276,7 +276,7 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence,
for (int offset = 16; offset > 0; for (int offset = 16; offset > 0;
offset = offset / 2) { // blockDim.x is 32. offset = offset / 2) { // blockDim.x is 32.
val += __shfl_down(val, offset); val += platform::__shfl_down_sync(0, val, offset);
} }
__syncthreads(); __syncthreads();
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/host_vector.h> #include <thrust/host_vector.h>
#include "paddle/fluid/operators/sequence_erase_op.h" #include "paddle/fluid/operators/sequence_erase_op.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include "paddle/fluid/operators/sequence_expand_op.h" #include "paddle/fluid/operators/sequence_expand_op.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/sgd_op.h" #include "paddle/fluid/operators/sgd_op.h"
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -66,5 +66,22 @@ CUDA_ATOMIC_WRAPPER(Add, double) { ...@@ -66,5 +66,22 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
} }
#endif #endif
// __shfl_down has been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
return __shfl_down(val, delta);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned mask, T val, int delta) {
return __shfl_down(mask, val, delta);
}
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -155,7 +155,7 @@ EOF ...@@ -155,7 +155,7 @@ EOF
function gen_dockerfile() { function gen_dockerfile() {
# Set BASE_IMAGE according to env variables # Set BASE_IMAGE according to env variables
if [[ ${WITH_GPU} == "ON" ]]; then if [[ ${WITH_GPU} == "ON" ]]; then
BASE_IMAGE="nvidia/cuda:8.0-cudnn7-runtime-ubuntu16.04" BASE_IMAGE="nvidia/cuda:8.0-cudnn7-devel-ubuntu16.04"
else else
BASE_IMAGE="ubuntu:16.04" BASE_IMAGE="ubuntu:16.04"
fi fi
......
...@@ -275,10 +275,7 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference): ...@@ -275,10 +275,7 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference):
class TestBatchNormOpTraining(unittest.TestCase): class TestBatchNormOpTraining(unittest.TestCase):
def __assert_close(self, tensor, np_array, msg, atol=1e-4): def __assert_close(self, tensor, np_array, msg, atol=1e-4):
if not np.allclose(np.array(tensor), np_array, atol=atol): np.allclose(np.array(tensor), np_array, atol=atol)
import pdb
pdb.set_trace()
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def test_forward_backward(self): def test_forward_backward(self):
def test_with_place(place, data_layout, shape): def test_with_place(place, data_layout, shape):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册