diff --git a/paddle/fluid/platform/device/gpu/cuda/cuda_helper.h b/paddle/fluid/platform/device/gpu/cuda/cuda_helper.h index d1d33d50a5dbb9b6d3059bc5c83acca049fb8a93..fa21a5f096611b0c6afda06615c5a26eef0d3dbb 100644 --- a/paddle/fluid/platform/device/gpu/cuda/cuda_helper.h +++ b/paddle/fluid/platform/device/gpu/cuda/cuda_helper.h @@ -70,11 +70,12 @@ namespace platform { * */ -#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type) \ - int64_t __index__ = \ - static_cast(blockIdx.x) * blockDim.x + threadIdx.x; \ - for (index_type i = __index__; __index__ < (num); \ - __index__ += blockDim.x * gridDim.x, i = __index__) +#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type) \ + int64_t __index__ = \ + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; \ + int64_t __stride__ = static_cast(blockDim.x) * gridDim.x; \ + for (index_type i = __index__; __index__ < (num); \ + __index__ += __stride__, i = __index__) class CublasHandleHolder { public: diff --git a/paddle/fluid/platform/device/gpu/rocm/rocm_helper.h b/paddle/fluid/platform/device/gpu/rocm/rocm_helper.h index 8bcae15d3517eb63ab2a1bacb8cca4b788911163..45eba2b1537c87c5b0181926f37e5a62091447f2 100644 --- a/paddle/fluid/platform/device/gpu/rocm/rocm_helper.h +++ b/paddle/fluid/platform/device/gpu/rocm/rocm_helper.h @@ -70,8 +70,9 @@ namespace platform { #define CUDA_KERNEL_LOOP_TYPE(i, num, index_type) \ int64_t __index__ = \ static_cast(hipBlockIdx_x) * hipBlockDim_x + hipThreadIdx_x; \ + int64_t __stride__ = static_cast(hipBlockDim_x) * hipGridDim_x; \ for (index_type i = __index__; __index__ < (num); \ - __index__ += hipBlockDim_x * hipGridDim_x, i = __index__) + __index__ += __stride__, i = __index__) class CublasHandleHolder { public: diff --git a/paddle/phi/backends/gpu/cuda/cuda_helper.h b/paddle/phi/backends/gpu/cuda/cuda_helper.h index 6d33d802b1880aeb87013a4c57fd1fb0fa23c9e9..7463edc5d9ff60dbe9f8a255458af122dec75f33 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_helper.h +++ b/paddle/phi/backends/gpu/cuda/cuda_helper.h @@ -62,11 +62,12 @@ namespace gpu { * */ -#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type) \ - int64_t __index__ = \ - static_cast(blockIdx.x) * blockDim.x + threadIdx.x; \ - for (index_type i = __index__; __index__ < (num); \ - __index__ += blockDim.x * gridDim.x, i = __index__) +#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type) \ + int64_t __index__ = \ + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; \ + int64_t __stride__ = static_cast(blockDim.x) * gridDim.x; \ + for (index_type i = __index__; __index__ < (num); \ + __index__ += __stride__, i = __index__) } // namespace gpu } // namespace backends diff --git a/paddle/phi/backends/gpu/rocm/rocm_helper.h b/paddle/phi/backends/gpu/rocm/rocm_helper.h index e25dea28e36c101254d951d887188f6a0ce6cbd5..07fdde5a2f417a7afff1dff3a404012d2a8409b1 100644 --- a/paddle/phi/backends/gpu/rocm/rocm_helper.h +++ b/paddle/phi/backends/gpu/rocm/rocm_helper.h @@ -65,8 +65,9 @@ namespace gpu { #define CUDA_KERNEL_LOOP_TYPE(i, num, index_type) \ int64_t __index__ = \ static_cast(hipBlockIdx_x) * hipBlockDim_x + hipThreadIdx_x; \ + int64_t __stride__ = static_cast(hipBlockDim_x) * hipGridDim_x; \ for (index_type i = __index__; __index__ < (num); \ - __index__ += hipBlockDim_x * hipGridDim_x, i = __index__) + __index__ += __stride__, i = __index__) } // namespace gpu } // namespace backends diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index 40dfb76586189ee0eac7a4790ed08b7c93514091..7d9efa46b7a5d1fdf9e15a9c692feb819b661c3c 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -468,6 +468,397 @@ void LaunchBroadcastKernel( func); } +#ifndef PADDLE_WITH_XPU_KP +HOSTDEVICE static int64_t ConvertSrcIdxToDstIdx( + int64_t src_idx, + const phi::Array &src_strides, + const phi::Array &dst_strides, + int rank) { + int64_t dst_idx = 0; + int64_t old_src_idx = src_idx; + for (int k = 0; k < rank; ++k) { + auto local_idx = src_idx / src_strides[k + 1]; + src_idx -= local_idx * src_strides[k + 1]; + + if (dst_strides[k] != dst_strides[k + 1]) { + dst_idx += local_idx * dst_strides[k + 1]; + } + } + return dst_idx; +} + +template +HOSTDEVICE static void ReadVecDataWithInt64Index( + const T *in, + int64_t idx, + bool need_broadcast, + const phi::Array &src_strides, + const phi::Array &dst_strides, + int rank, + int n, + phi::AlignedVector *out) { + if (IsBoundary) { + for (int i = 0; i < n; ++i) { + (*out)[i] = + in[ConvertSrcIdxToDstIdx(idx + i, src_strides, dst_strides, rank)]; + } + } else { + if (!need_broadcast) { + phi::Load(in + idx, out); + } else { +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + (*out)[i] = + in[ConvertSrcIdxToDstIdx(idx + i, src_strides, dst_strides, rank)]; + } + } + } +} + +template +struct ApplyFunctorWithInt64IndexHelper { + HOSTDEVICE static OutT Run(const phi::AlignedVector *ins_vec, + Functor functor, + int i); +}; + +template +struct ApplyFunctorWithInt64IndexHelper { + HOSTDEVICE static OutT Run(const phi::AlignedVector *ins_vec, + Functor functor, + int i) { + return static_cast(functor()); + } +}; + +template +struct ApplyFunctorWithInt64IndexHelper { + HOSTDEVICE static OutT Run(const phi::AlignedVector *ins_vec, + Functor functor, + int i) { + return static_cast(functor(ins_vec[0][i])); + } +}; + +template +struct ApplyFunctorWithInt64IndexHelper { + HOSTDEVICE static OutT Run(const phi::AlignedVector *ins_vec, + Functor functor, + int i) { + return static_cast(functor(ins_vec[0][i], ins_vec[1][i])); + } +}; + +template +struct ApplyFunctorWithInt64IndexHelper { + HOSTDEVICE static OutT Run(const phi::AlignedVector *ins_vec, + Functor functor, + int i) { + return static_cast( + functor(ins_vec[0][i], ins_vec[1][i], ins_vec[2][i])); + } +}; + +template +struct MaxWithOne { + static constexpr auto kValue = (N >= 1 ? N : 1); +}; + +template +__global__ void BroadcastKernelWithInt64Index( + phi::Array::kValue> ins, + OutT *out, + phi::Array, + MaxWithOne::kValue> ins_strides, + phi::Array out_strides, + phi::Array::kValue> need_broadcasts, + int rank, + Functor functor) { + int64_t numel = out_strides[0]; + int64_t idx = + (static_cast(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize; + int64_t stride = static_cast(blockDim.x) * gridDim.x * VecSize; + int64_t limit = numel - VecSize; + + phi::Array, MaxWithOne::kValue> + ins_vec; + phi::AlignedVector out_vec; + for (; idx <= limit; idx += stride) { +#pragma unroll + for (int i = 0; i < NumIns; ++i) { + ReadVecDataWithInt64Index(ins[i], + idx, + need_broadcasts[i], + out_strides, + ins_strides[i], + rank, + VecSize, + &ins_vec[i]); + } + +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + out_vec[i] = ApplyFunctorWithInt64IndexHelper::Run(ins_vec.Get(), + functor, + i); + } + + phi::Store(out_vec, out + idx); + } + + if (idx < numel) { + int remain = numel - idx; // remain is always less than VecSize, therefore + // `int` is enough here +#pragma unroll + for (int i = 0; i < NumIns; ++i) { + ReadVecDataWithInt64Index(ins[i], + idx, + need_broadcasts[i], + out_strides, + ins_strides[i], + rank, + remain, + &ins_vec[i]); + } + + for (int i = 0; i < remain; ++i) { + out[idx + i] = + ApplyFunctorWithInt64IndexHelper::Run(ins_vec.Get(), + functor, + i); + } + } +} + +template +struct LaunchBroadcastKernelWithInt64IndexHelper { + static void Run(const KPDevice &ctx, + const std::vector &ins, + std::vector *outs, + int axis, + Functor functor) { + PADDLE_THROW(phi::errors::PermissionDenied( + "Unreachable code branch. This may be a bug.")); + } +}; + +template +struct LaunchBroadcastKernelWithInt64IndexHelper { + static void Run(const KPDevice &ctx, + const std::vector &ins, + std::vector *outs, + int axis, + Functor functor) { + phi::Array::kValue> ins_ptrs; + for (int i = 0; i < Arity; ++i) { + ins_ptrs[i] = ins[i]->data(); + } + auto *out_tensor = (*outs)[0]; + auto *out_ptr = ctx.Alloc(out_tensor); + + phi::Array, + MaxWithOne::kValue> + ins_expand_dims; + phi::Array broadcast_out_dims; + int rank; + if (Arity == 1) { + rank = ins[0]->dims().size(); + for (int i = 0; i < rank; ++i) { + broadcast_out_dims[i] = ins[0]->dims()[i]; + } + ins_expand_dims[0] = broadcast_out_dims; + } else if (Arity >= 2) { + CalculateBroadcastDims(ins[0]->dims().Get(), + ins[1]->dims().Get(), + ins[0]->dims().size(), + ins[1]->dims().size(), + axis, + ins_expand_dims[0].GetMutable(), + ins_expand_dims[1].GetMutable(), + broadcast_out_dims.GetMutable(), + &rank); + for (int i = 2; i < Arity; ++i) { + auto tmp_dims = broadcast_out_dims; + phi::Array tmp_expand_dims; + int tmp_rank; + PADDLE_ENFORCE_GE(rank, + ins[i]->dims().size(), + phi::errors::InvalidArgument( + "Unsupported reverse broadcast when the input " + "tensor number is larger than 2.")); + CalculateBroadcastDims(tmp_dims.Get(), + ins[i]->dims().Get(), + rank, + ins[i]->dims().size(), + axis, + tmp_expand_dims.GetMutable(), + ins_expand_dims[i].GetMutable(), + broadcast_out_dims.GetMutable(), + &tmp_rank); + PADDLE_ENFORCE_EQ(rank, + tmp_rank, + phi::errors::InvalidArgument( + "Wrong broadcast algorithm. This may be a bug.")); + } + } + + phi::Array, + MaxWithOne::kValue> + ins_strides; + phi::Array::kValue> need_broadcasts; + phi::Array out_strides; + const auto &out_dims = out_tensor->dims(); + if (rank <= out_dims.size()) { + out_strides = ShapeToStride(out_dims.Get(), rank); + } else { + out_strides = ShapeToStride(broadcast_out_dims.Get(), rank); + } + + for (int i = 0; i < Arity; ++i) { + ins_strides[i] = ShapeToStride(ins_expand_dims[i].Get(), rank); + need_broadcasts[i] = + !IsSameShape(out_strides.Get(), ins_strides[i].Get(), rank + 1); + } + + int64_t numel = out_strides[0]; + auto gpu_config = + phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize); + + BroadcastKernelWithInt64Index + <<>>(ins_ptrs, + out_ptr, + ins_strides, + out_strides, + need_broadcasts, + rank, + functor); + } + + private: + static void CalculateBroadcastDims(const int64_t *x_dims, + const int64_t *y_dims, + int nx, + int ny, + int axis, + int64_t *x_out_dims, + int64_t *y_out_dims, + int64_t *broadcast_out_dims, + int *length) { + PADDLE_ENFORCE_GE( + axis, 0, phi::errors::InvalidArgument("Invalid axis value: %d", axis)); + if (nx == ny) { + *length = nx; + for (int i = 0; i < nx; ++i) { + if (x_dims[i] != y_dims[i]) { + PADDLE_ENFORCE_EQ( + x_dims[i] == 1 || y_dims[i] == 1, + true, + phi::errors::InvalidArgument("Cannot broadcast input shape where " + "x_dims[%d] = %d, y_dims[%d] = %d.", + i, + x_dims[i], + i, + y_dims[i])); + } + broadcast_out_dims[i] = std::max(x_dims[i], y_dims[i]); + x_out_dims[i] = x_dims[i]; + y_out_dims[i] = y_dims[i]; + } + } else if (nx > ny) { + *length = nx; + for (int i = nx - axis; i < ny; ++i) { + PADDLE_ENFORCE_EQ( + y_dims[i], + 1, + phi::errors::InvalidArgument( + "The trailing Y.shape[%d] should be 1 but got %d.", + i, + y_dims[i])); + } + + for (int i = 0; i < nx; ++i) { + if (i >= axis && i - axis < ny) { + if (x_dims[i] != y_dims[i - axis]) { + PADDLE_ENFORCE_EQ(x_dims[i] == 1 || y_dims[i - axis] == 1, + true, + phi::errors::InvalidArgument( + "Cannot broadcast input shape where " + "x_dims[%d] = %d, y_dims[%d] = %d.", + i, + x_dims[i], + i - axis, + y_dims[i - axis])); + } + broadcast_out_dims[i] = std::max(x_dims[i], y_dims[i - axis]); + x_out_dims[i] = x_dims[i]; + y_out_dims[i] = y_dims[i - axis]; + } else { + broadcast_out_dims[i] = x_dims[i]; + x_out_dims[i] = x_dims[i]; + y_out_dims[i] = 1; + } + } + } else { + CalculateBroadcastDims(y_dims, + x_dims, + ny, + nx, + axis, + y_out_dims, + x_out_dims, + broadcast_out_dims, + length); + } + } + + static bool IsSameShape(const int64_t *x, const int64_t *y, int rank) { + for (int i = 0; i < rank; ++i) { + if (x[i] != y[i]) return false; + } + return true; + } + + static phi::Array ShapeToStride( + const int64_t *arr, int rank) { + phi::Array strides; + strides[rank] = 1; + for (int i = rank - 1; i >= 0; --i) { + strides[i] = strides[i + 1] * arr[i]; + } + return strides; + } +}; +#endif + template size(), NumOuts)); +#ifndef PADDLE_WITH_XPU_KP + constexpr bool kEnabledInt64IndexKernel = (NumOuts == 1 && kArity <= 3); + bool use_int64_index_kernel = + kEnabledInt64IndexKernel && + (*outs)[0]->numel() >= std::numeric_limits::max(); + if (use_int64_index_kernel) { + int vec_size = GetVecsize(ins, outs); + switch (vec_size) { + case VecSizeL: { + LaunchBroadcastKernelWithInt64IndexHelper::Run(ctx, + ins, + outs, + axis, + func); + break; + } + case VecSizeM: { + LaunchBroadcastKernelWithInt64IndexHelper::Run(ctx, + ins, + outs, + axis, + func); + break; + } + case VecSizeS: { + LaunchBroadcastKernelWithInt64IndexHelper::Run(ctx, + ins, + outs, + axis, + func); + break; + } + default: { + PADDLE_THROW(phi::errors::Unimplemented( + "Unsupported vectorized size: %d!", vec_size)); + break; + } + } + return; + } +#endif + // mergedim and get vec_size const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis); phi::Array configs; diff --git a/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu b/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu index 224651326d7626a088e76d092349bf917727a97a..31227e59433ea88b7642dbbada9c3aff98de4ecb 100644 --- a/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu +++ b/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu @@ -92,7 +92,7 @@ struct BinaryOperation { std::vector outs{output}; paddle::operators:: LaunchElementwiseCudaKernel( - dev_ctx, ins, &outs, -1, BinaryFunctor()); + dev_ctx, ins, &outs, 0, BinaryFunctor()); } }; diff --git a/tools/dockerfile/Dockerfile.release16 b/tools/dockerfile/Dockerfile.release16 index 66974f46d91e495e3d350b57c2ff5a91f95b4385..482518bf28305228d257a12be4f4123b8952a72e 100644 --- a/tools/dockerfile/Dockerfile.release16 +++ b/tools/dockerfile/Dockerfile.release16 @@ -101,8 +101,13 @@ RUN curl -s -q https://glide.sh/get | sh # Downgrade TensorRT COPY tools/dockerfile/build_scripts /build_scripts RUN bash /build_scripts/install_nccl2.sh -RUN rm -rf /build_scripts +# Older versions of patchelf limited the size of the files being processed and were fixed in this pr. +# # https://github.com/NixOS/patchelf/commit/ba2695a8110abbc8cc6baf0eea819922ee5007fa +# # So install a newer version here. +RUN bash /build_scripts/install_patchelf.sh + +RUN rm -rf /build_scripts # git credential to skip password typing RUN git config --global credential.helper store @@ -143,13 +148,6 @@ RUN wget -q https://launchpad.net/ubuntu/+archive/primary/+sourcefiles/binutils/ RUN apt-get install libprotobuf-dev -y -# Older versions of patchelf limited the size of the files being processed and were fixed in this pr. -# https://github.com/NixOS/patchelf/commit/ba2695a8110abbc8cc6baf0eea819922ee5007fa -# So install a newer version here. -RUN wget -q https://paddle-ci.cdn.bcebos.com/patchelf_0.10-2_amd64.deb && \ - dpkg -i patchelf_0.10-2_amd64.deb && \ - rm -rf patchelf_0.10-2_amd64.deb - # Configure OpenSSH server. c.f. https://docs.docker.com/engine/examples/running_ssh_service RUN mkdir /var/run/sshd && echo 'root:root' | chpasswd && sed -ri 's/^PermitRootLogin\s+.*/PermitRootLogin yes/' /etc/ssh/sshd_config && sed -ri 's/UsePAM yes/#UsePAM yes/g' /etc/ssh/sshd_config CMD source ~/.bashrc diff --git a/tools/dockerfile/Dockerfile.release18 b/tools/dockerfile/Dockerfile.release18 index d646f41b00d0b97b38f9ee0e0fd5af86de12697a..fe8513d662badd71c3bc1c3a0363bb7e6ebd7a8b 100644 --- a/tools/dockerfile/Dockerfile.release18 +++ b/tools/dockerfile/Dockerfile.release18 @@ -28,6 +28,10 @@ RUN apt-get update && \ # Downgrade gcc&&g++ WORKDIR /usr/bin COPY tools/dockerfile/build_scripts /build_scripts +# Older versions of patchelf limited the size of the files being processed and were fixed in this pr. +# https://github.com/NixOS/patchelf/commit/ba2695a8110abbc8cc6baf0eea819922ee5007fa +# So install a newer version here. +RUN bash /build_scripts/install_patchelf.sh RUN bash /build_scripts/install_gcc.sh gcc82 && rm -rf /build_scripts RUN cp gcc gcc.bak && cp g++ g++.bak && rm gcc && rm g++ RUN ln -s /usr/local/gcc-8.2/bin/gcc /usr/local/bin/gcc @@ -99,14 +103,6 @@ RUN pip3.7 --no-cache-dir install pylint pytest astroid isort COPY ./python/requirements.txt /root/ RUN pip3.7 --no-cache-dir install -r /root/requirements.txt - -# Older versions of patchelf limited the size of the files being processed and were fixed in this pr. -# https://github.com/NixOS/patchelf/commit/ba2695a8110abbc8cc6baf0eea819922ee5007fa -# So install a newer version here. -RUN wget -q https://paddle-ci.cdn.bcebos.com/patchelf_0.10-2_amd64.deb && \ - dpkg -i patchelf_0.10-2_amd64.deb && \ - rm -rf patchelf_0.10-2_amd64.deb - # Configure OpenSSH server. c.f. https://docs.docker.com/engine/examples/running_ssh_service #RUN mkdir /var/run/sshd && echo 'root:root' | chpasswd && sed -ri 's/^PermitRootLogin\s+.*/PermitRootLogin yes/' /etc/ssh/sshd_config && sed -ri 's/UsePAM yes/#UsePAM yes/g' /etc/ssh/sshd_config #CMD source ~/.bashrc diff --git a/tools/dockerfile/Dockerfile.ubuntu b/tools/dockerfile/Dockerfile.ubuntu index 7e0c3a62b1d5011346cfd04f57cbc4c0ebe7214f..b165f757325b4e43e166eb2eb9ddb633ecd896d3 100644 --- a/tools/dockerfile/Dockerfile.ubuntu +++ b/tools/dockerfile/Dockerfile.ubuntu @@ -143,9 +143,14 @@ RUN curl -s -q https://glide.sh/get | sh # See https://github.com/PaddlePaddle/Paddle/issues/10129 for details. # Downgrade TensorRT + +# Older versions of patchelf limited the size of the files being processed and were fixed in this pr. +# https://github.com/NixOS/patchelf/commit/ba2695a8110abbc8cc6baf0eea819922ee5007fa +# So install a newer version here. COPY tools/dockerfile/build_scripts /build_scripts RUN bash /build_scripts/install_trt.sh && \ - bash /build_scripts/install_nccl2.sh + bash /build_scripts/install_nccl2.sh && \ + bash /build_scripts/install_patchelf.sh RUN rm -rf /build_scripts # git credential to skip password typing @@ -236,13 +241,6 @@ RUN wget -q https://launchpad.net/ubuntu/+archive/primary/+sourcefiles/binutils/ RUN apt-get install libprotobuf-dev -y -# Older versions of patchelf limited the size of the files being processed and were fixed in this pr. -# https://github.com/NixOS/patchelf/commit/ba2695a8110abbc8cc6baf0eea819922ee5007fa -# So install a newer version here. -RUN wget -q https://paddle-ci.cdn.bcebos.com/patchelf_0.10-2_amd64.deb && \ - dpkg -i patchelf_0.10-2_amd64.deb && \ - rm -rf patchelf_0.10-2_amd64.deb - # Configure OpenSSH server. c.f. https://docs.docker.com/engine/examples/running_ssh_service RUN mkdir /var/run/sshd && echo 'root:root' | chpasswd && sed -ri 's/^PermitRootLogin\s+.*/PermitRootLogin yes/' /etc/ssh/sshd_config && sed -ri 's/UsePAM yes/#UsePAM yes/g' /etc/ssh/sshd_config CMD source ~/.bashrc diff --git a/tools/dockerfile/Dockerfile.ubuntu18 b/tools/dockerfile/Dockerfile.ubuntu18 index a5dba053b98b2eb8839b47907c4e9d3de7098af6..8ebfd9b8371c20a049669b8bdb8309b764335625 100644 --- a/tools/dockerfile/Dockerfile.ubuntu18 +++ b/tools/dockerfile/Dockerfile.ubuntu18 @@ -35,6 +35,10 @@ RUN apt-get update --allow-unauthenticated && \ WORKDIR /usr/bin COPY tools/dockerfile/build_scripts /build_scripts RUN bash /build_scripts/install_trt.sh +# Older versions of patchelf limited the size of the files being processed and were fixed in this pr. +# # https://github.com/NixOS/patchelf/commit/ba2695a8110abbc8cc6baf0eea819922ee5007fa +# # So install a newer version here. +RUN bash /build_scripts/install_patchelf.sh RUN bash /build_scripts/install_gcc.sh gcc82 && rm -rf /build_scripts RUN cp gcc gcc.bak && cp g++ g++.bak && rm gcc && rm g++ RUN ln -s /usr/local/gcc-8.2/bin/gcc /usr/local/bin/gcc @@ -151,14 +155,6 @@ RUN pip3.6 --no-cache-dir install -r /root/requirements.txt && \ pip3.8 --no-cache-dir install -r /root/requirements.txt && \ pip3.9 --no-cache-dir install -r /root/requirements.txt - -# Older versions of patchelf limited the size of the files being processed and were fixed in this pr. -# https://github.com/NixOS/patchelf/commit/ba2695a8110abbc8cc6baf0eea819922ee5007fa -# So install a newer version here. -RUN wget -q https://paddle-ci.cdn.bcebos.com/patchelf_0.10-2_amd64.deb && \ - dpkg -i patchelf_0.10-2_amd64.deb && \ - rm -rf patchelf_0.10-2_amd64.deb - # Configure OpenSSH server. c.f. https://docs.docker.com/engine/examples/running_ssh_service #RUN mkdir /var/run/sshd && echo 'root:root' | chpasswd && sed -ri 's/^PermitRootLogin\s+.*/PermitRootLogin yes/' /etc/ssh/sshd_config && sed -ri 's/UsePAM yes/#UsePAM yes/g' /etc/ssh/sshd_config #CMD source ~/.bashrc diff --git a/tools/dockerfile/build_scripts/build.sh b/tools/dockerfile/build_scripts/build.sh index 92d1c12d2bc412409d21794315f66fc4ec7228bd..61bcc1f103563072d7f4fc4b4ef3823698a35a72 100644 --- a/tools/dockerfile/build_scripts/build.sh +++ b/tools/dockerfile/build_scripts/build.sh @@ -106,7 +106,7 @@ export SSL_CERT_FILE=/opt/_internal/certs.pem # tar -xzf patchelf-0.9njs2.tar.gz # (cd patchelf-0.9njs2 && ./configure && make && make install) # rm -rf patchelf-0.9njs2.tar.gz patchelf-0.9njs2 -yum install -y patchelf +sh "$MY_DIR/install_patchelf.sh" # Install latest pypi release of auditwheel #LD_LIBRARY_PATH="${ORIGINAL_LD_LIBRARY_PATH}:$(dirname ${PY35_BIN})/lib" $PY35_BIN/pip install auditwheel diff --git a/tools/dockerfile/build_scripts/install_patchelf.sh b/tools/dockerfile/build_scripts/install_patchelf.sh new file mode 100644 index 0000000000000000000000000000000000000000..9fda46e5b6f865634ec1d54836e1ec8c4cc7589d --- /dev/null +++ b/tools/dockerfile/build_scripts/install_patchelf.sh @@ -0,0 +1,29 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +TMP_DIR=patchelf_tmp + +rm -rf "$TMP_DIR" +git clone -b 0.15.0 https://github.com/NixOS/patchelf "$TMP_DIR" + +cd "$TMP_DIR" +./bootstrap.sh +./configure +make +make install + +cd .. +rm -rf "$TMP_DIR"