diff --git a/paddle/fluid/operators/detection/bbox_util.cu.h b/paddle/fluid/operators/detection/bbox_util.cu.h
index 0d52fd416138273354c2ef6266eed350c91ed54c..27852d43948327c498e73a51e4151a25c31f64c3 100644
--- a/paddle/fluid/operators/detection/bbox_util.cu.h
+++ b/paddle/fluid/operators/detection/bbox_util.cu.h
@@ -23,6 +23,7 @@ limitations under the License. */
 #ifdef __HIPCC__
 #include <hipcub/hipcub.hpp>
 #include "paddle/fluid/platform/miopen_helper.h"
+namespace cub = hipcub;
 #endif
 #include "paddle/fluid/operators/gather.cu.h"
 #include "paddle/fluid/operators/math/math_function.h"
@@ -64,27 +65,16 @@ static void SortDescending(const platform::CUDADeviceContext &ctx,
   // Determine temporary device storage requirements
   size_t temp_storage_bytes = 0;
-#ifdef PADDLE_WITH_HIP
-  hipcub::DeviceRadixSort::SortPairsDescending<T, int>(
-      nullptr, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, num);
-#else
   cub::DeviceRadixSort::SortPairsDescending<T, int>(
       nullptr, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, num);
-#endif
   // Allocate temporary storage
   auto place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
   auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
 
-// Run sorting operation
-#ifdef PADDLE_WITH_HIP
-  hipcub::DeviceRadixSort::SortPairsDescending<T, int>(
-      d_temp_storage->ptr(), temp_storage_bytes, keys_in, keys_out, idx_in,
-      idx_out, num);
-#else
+  // Run sorting operation
   cub::DeviceRadixSort::SortPairsDescending<T, int>(
       d_temp_storage->ptr(), temp_storage_bytes, keys_in, keys_out, idx_in,
       idx_out, num);
-#endif
 }
 
 template <typename T>
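Aside, not part of the patch: the `namespace cub = hipcub;` alias added above is what lets every later hunk delete its hipcub branch, since under hipcc the `cub::` spellings resolve to hipcub. A minimal standalone sketch of the resulting portable pattern; `SortDescendingSketch` is a made-up name, and raw hipMalloc/cudaMalloc stand in for the `memory::Alloc` the kernel above uses:

    #include <cstddef>
    #ifdef __HIPCC__
    #include <hip/hip_runtime.h>
    #include <hipcub/hipcub.hpp>
    namespace cub = hipcub;  // one alias instead of an #ifdef at every call site
    #else
    #include <cuda_runtime.h>
    #include <cub/cub.cuh>
    #endif

    // Two-phase CUB protocol: a null buffer makes the call report its scratch
    // size; the second call with a real buffer does the sort. All pointers are
    // device pointers.
    inline void SortDescendingSketch(const float* keys_in, float* keys_out,
                                     const int* idx_in, int* idx_out, int num) {
      size_t temp_storage_bytes = 0;
      cub::DeviceRadixSort::SortPairsDescending(
          nullptr, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, num);
      void* d_temp_storage = nullptr;
    #ifdef __HIPCC__
      hipMalloc(&d_temp_storage, temp_storage_bytes);
    #else
      cudaMalloc(&d_temp_storage, temp_storage_bytes);
    #endif
      cub::DeviceRadixSort::SortPairsDescending(
          d_temp_storage, temp_storage_bytes, keys_in, keys_out, idx_in,
          idx_out, num);
    #ifdef __HIPCC__
      hipFree(d_temp_storage);
    #else
      cudaFree(d_temp_storage);
    #endif
    }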
diff --git a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu
index 4bb0f9ca67fb2bbda43a9a8ec6ea15f118d0ed0c..bc74c80e0315fac6de3ca575d53b23965adf4179 100644
--- a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu
+++ b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu
@@ -14,6 +14,7 @@ limitations under the License. */
 #endif
 #ifdef __HIPCC__
 #include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
 #endif
 
 #include <paddle/fluid/memory/allocation/allocator.h>
@@ -141,29 +142,17 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
     // Determine temporary device storage requirements
     size_t temp_storage_bytes = 0;
-#ifdef PADDLE_WITH_HIP
-    hipcub::DeviceRadixSort::SortPairsDescending<T, int>(
-        nullptr, temp_storage_bytes, concat_scores.data<T>(), keys_out, idx_in,
-        idx_out, total_roi_num);
-#else
     cub::DeviceRadixSort::SortPairsDescending<T, int>(
         nullptr, temp_storage_bytes, concat_scores.data<T>(), keys_out, idx_in,
         idx_out, total_roi_num);
-#endif
     // Allocate temporary storage
     auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
 
-// Run sorting operation
-// sort score to get corresponding index
-#ifdef PADDLE_WITH_HIP
-    hipcub::DeviceRadixSort::SortPairsDescending<T, int>(
-        d_temp_storage->ptr(), temp_storage_bytes, concat_scores.data<T>(),
-        keys_out, idx_in, idx_out, total_roi_num);
-#else
+    // Run sorting operation
+    // sort score to get corresponding index
     cub::DeviceRadixSort::SortPairsDescending<T, int>(
         d_temp_storage->ptr(), temp_storage_bytes, concat_scores.data<T>(),
         keys_out, idx_in, idx_out, total_roi_num);
-#endif
 
     index_out_t.Resize({real_post_num});
     Tensor sorted_rois;
     sorted_rois.mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace());
@@ -185,29 +174,17 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
     out_id_t.mutable_data<int>({real_post_num}, dev_ctx.GetPlace());
     // Determine temporary device storage requirements
     temp_storage_bytes = 0;
-#ifdef PADDLE_WITH_HIP
-    hipcub::DeviceRadixSort::SortPairs<int, int>(
-        nullptr, temp_storage_bytes, sorted_batch_id.data<int>(), out_id_data,
-        batch_idx_in, index_out_t.data<int>(), real_post_num);
-#else
     cub::DeviceRadixSort::SortPairs<int, int>(
         nullptr, temp_storage_bytes, sorted_batch_id.data<int>(), out_id_data,
         batch_idx_in, index_out_t.data<int>(), real_post_num);
-#endif
     // Allocate temporary storage
     d_temp_storage = memory::Alloc(place, temp_storage_bytes);
 
-// Run sorting operation
-// sort batch_id to get corresponding index
-#ifdef PADDLE_WITH_HIP
-    hipcub::DeviceRadixSort::SortPairs<int, int>(
-        d_temp_storage->ptr(), temp_storage_bytes, sorted_batch_id.data<int>(),
-        out_id_data, batch_idx_in, index_out_t.data<int>(), real_post_num);
-#else
+    // Run sorting operation
+    // sort batch_id to get corresponding index
     cub::DeviceRadixSort::SortPairs<int, int>(
         d_temp_storage->ptr(), temp_storage_bytes, sorted_batch_id.data<int>(),
         out_id_data, batch_idx_in, index_out_t.data<int>(), real_post_num);
-#endif
 
     GPUGather<T>(dev_ctx, sorted_rois, index_out_t, fpn_rois);
diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu
index 63f205947d9b5df1354245754c7d3544a6353039..cc61035309eaab31534119ab088bf537bf71c242 100644
--- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu
+++ b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu
@@ -17,6 +17,7 @@ limitations under the License. */
 #endif
 #ifdef __HIPCC__
 #include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
 #endif
 
 #include <paddle/fluid/memory/allocation/allocator.h>
@@ -149,42 +150,24 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
     // Determine temporary device storage requirements
     size_t temp_storage_bytes = 0;
-#ifdef PADDLE_WITH_HIP
-    hipcub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes,
-                                       target_lvls_data, keys_out,
-                                       idx_in, idx_out, roi_num);
-#else
     cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes,
                                     target_lvls_data, keys_out, idx_in,
                                     idx_out, roi_num);
-#endif
     // Allocate temporary storage
     auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
 
-// Run sorting operation
-// sort target level to get corresponding index
-#ifdef PADDLE_WITH_HIP
-    hipcub::DeviceRadixSort::SortPairs(
-        d_temp_storage->ptr(), temp_storage_bytes, target_lvls_data, keys_out,
-        idx_in, idx_out, roi_num);
-#else
+    // Run sorting operation
+    // sort target level to get corresponding index
     cub::DeviceRadixSort::SortPairs(
         d_temp_storage->ptr(), temp_storage_bytes, target_lvls_data, keys_out,
        idx_in, idx_out, roi_num);
-#endif
 
     int* restore_idx_data =
         restore_index->mutable_data<int>({roi_num, 1}, dev_ctx.GetPlace());
-// sort current index to get restore index
-#ifdef PADDLE_WITH_HIP
-    hipcub::DeviceRadixSort::SortPairs(
-        d_temp_storage->ptr(), temp_storage_bytes, idx_out, keys_out, idx_in,
-        restore_idx_data, roi_num);
-#else
+    // sort current index to get restore index
     cub::DeviceRadixSort::SortPairs(
         d_temp_storage->ptr(), temp_storage_bytes, idx_out, keys_out, idx_in,
         restore_idx_data, roi_num);
-#endif
 
     int start = 0;
     auto multi_rois_num = ctx.MultiOutput<Tensor>("MultiLevelRoIsNum");
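Aside, not part of the patch: the last SortPairs above computes an inverse permutation. After the first sort, idx_out[i] holds the original position of sorted row i; sorting idx_out as keys with 0..n-1 as values yields, for each original row, where it landed, which is the "restore index". A CPU analogue for illustration only, with std::sort standing in for DeviceRadixSort::SortPairs:

    #include <algorithm>
    #include <cstdio>
    #include <utility>
    #include <vector>

    int main() {
      // perm[i] = original position of sorted element i (output of sort #1)
      std::vector<int> perm = {2, 0, 3, 1};
      std::vector<std::pair<int, int>> kv;
      for (int i = 0; i < 4; ++i) kv.push_back({perm[i], i});
      std::sort(kv.begin(), kv.end());  // sort #2: keys = perm, values = 0..n-1
      for (auto& p : kv) std::printf("%d ", p.second);  // restore index: 1 3 0 2
      return 0;
    }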
diff --git a/paddle/fluid/operators/group_norm_op.cu b/paddle/fluid/operators/group_norm_op.cu
index 2a550486929ec3057a49b7931bfe38a19df5b691..45d97723a3e21044daf1609b749a22ae08efad39 100644
--- a/paddle/fluid/operators/group_norm_op.cu
+++ b/paddle/fluid/operators/group_norm_op.cu
@@ -17,6 +17,7 @@ limitations under the License. */
 #endif
 #ifdef __HIPCC__
 #include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
 #endif
 
 #include "paddle/fluid/operators/group_norm_op.h"
@@ -46,18 +47,10 @@ enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 };
 
 template <typename T>
 __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
-#ifdef PADDLE_WITH_CUDA
   typedef cub::WarpReduce<T> WarpReduce;
-#else
-  typedef hipcub::WarpReduce<T> WarpReduce;
-#endif
   typename WarpReduce::TempStorage temp_storage;
   value = WarpReduce(temp_storage).Sum(value);
-#ifdef PADDLE_WITH_CUDA
   if (cub::LaneId() == 0) platform::CudaAtomicAdd(sum, value);
-#else
-  if (hipcub::LaneId() == 0) platform::CudaAtomicAdd(sum, value);
-#endif
 }
 
 template <typename T>
diff --git a/paddle/fluid/operators/kron_op.h b/paddle/fluid/operators/kron_op.h
index e74f537c852f65ed4c63577cbcc6809965037697..6815fd460fa1f1969c9bf01f733f30b941fd8799 100644
--- a/paddle/fluid/operators/kron_op.h
+++ b/paddle/fluid/operators/kron_op.h
@@ -369,19 +369,7 @@ struct KronGradOpFunctor {
     for_range(func);
 
     // reduce_sum along aixs 1
-#ifdef __HIPCC__
-    auto stream = dev_ctx.stream();  // it is a cuda device_context
-    if (dx) {
-      TensorReduce<T, T, hipcub::Sum, IdentityFunctor<T>>(
-          dout_x, dx, {1}, static_cast<T>(0), hipcub::Sum(),
-          IdentityFunctor<T>(), stream);
-    }
-    if (dy) {
-      TensorReduce<T, T, hipcub::Sum, IdentityFunctor<T>>(
-          dout_y, dy, {1}, static_cast<T>(0), hipcub::Sum(),
-          IdentityFunctor<T>(), stream);
-    }
-#elif defined(__NVCC__)
+#if defined(__NVCC__) || defined(__HIPCC__)
     auto stream = dev_ctx.stream();  // it is a cuda device_context
     if (dx) {
       TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h
index f93a87831f1e83a88c0b1662cb1aa7fc92e253e6..ca20efaad074d76271e6c06992dcf0cc53a8739a 100644
--- a/paddle/fluid/operators/matmul_v2_op.h
+++ b/paddle/fluid/operators/matmul_v2_op.h
@@ -45,12 +45,7 @@ template <typename T>
 void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output,
                             const std::vector<int>& reduce_dims,
                             const paddle::framework::ExecutionContext& ctx) {
-#ifdef __HIPCC__
-  auto stream = ctx.cuda_device_context().stream();
-  TensorReduce<T, T, hipcub::Sum, IdentityFunctor<T>>(
-      *input, output, reduce_dims, static_cast<T>(0), hipcub::Sum(),
-      IdentityFunctor<T>(), stream);
-#elif defined(__NVCC__)
+#if defined(__NVCC__) || defined(__HIPCC__)
   auto stream = ctx.cuda_device_context().stream();
   TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
       *input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
diff --git a/paddle/fluid/operators/pool_op.h b/paddle/fluid/operators/pool_op.h
index a738816c4006e572123844160c5d8d6c827a2554..9117b1b95ed26d03e30c59aa1f77e5de1c2b7755 100644
--- a/paddle/fluid/operators/pool_op.h
+++ b/paddle/fluid/operators/pool_op.h
@@ -213,12 +213,7 @@ class PoolKernel : public framework::OpKernel<T> {
         if (reduce_num > 0 &&
             adaptive) {  // for adaptive_avg_pool2d && output_size == 1
-#ifdef __HIPCC__
-          auto stream = dev_ctx.stream();
-          TensorReduce<T, T, hipcub::Sum, DivideFunctor<T>>(
-              *in_x, out, reduce_dim, static_cast<T>(0), hipcub::Sum(),
-              DivideFunctor<T>(reduce_num), stream);
-#elif defined(__NVCC__)
+#if defined(__HIPCC__) || defined(__NVCC__)
           auto stream = dev_ctx.stream();
           TensorReduce<T, T, cub::Sum, DivideFunctor<T>>(
               *in_x, out, reduce_dim, static_cast<T>(0), cub::Sum(),
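Aside, not part of the patch: group_norm's CudaAtomicAddWithWarp now has one body because cub::WarpReduce and cub::LaneId() resolve to the hipcub implementations through the alias, and the kron/matmul/pool hunks collapse `#ifdef __HIPCC__ ... #elif defined(__NVCC__)` into a single guard for the same reason. A toy kernel showing the portable spelling; `CountWarps` is illustrative only. Note that a ROCm wavefront has 64 lanes where a CUDA warp has 32, a difference cub::LaneId() hides:

    #ifdef __HIPCC__
    #include <hipcub/hipcub.hpp>
    namespace cub = hipcub;
    #else
    #include <cub/cub.cuh>
    #endif

    // One increment per warp (CUDA) or wavefront (ROCm): only lane 0 of each
    // lane group executes the atomic.
    __global__ void CountWarps(int* counter) {
      if (cub::LaneId() == 0) {
        atomicAdd(counter, 1);
      }
    }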
diff --git a/paddle/fluid/operators/prelu_op.cu b/paddle/fluid/operators/prelu_op.cu
index 52ce37878c223d24ecd370c3d7b979d655bdc1e6..ca01487549fe687225b12f4e7fe75fd8d3e660f2 100644
--- a/paddle/fluid/operators/prelu_op.cu
+++ b/paddle/fluid/operators/prelu_op.cu
@@ -174,15 +174,9 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
       reduce_dims.push_back(i);
     }
 
-#ifdef __HIPCC__
-    TensorReduce<T, T, hipcub::Sum, IdentityFunctor<T>>(
-        dalpha_tmp, dalpha, reduce_dims, static_cast<T>(0), hipcub::Sum(),
-        IdentityFunctor<T>(), stream);
-#else
     TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
         dalpha_tmp, dalpha, reduce_dims, static_cast<T>(0), cub::Sum(),
         IdentityFunctor<T>(), stream);
-#endif
   }
 };
 
diff --git a/paddle/fluid/operators/reduce_ops/cub_reduce.h b/paddle/fluid/operators/reduce_ops/cub_reduce.h
index dad7c848a6c8d2ae00606996cac3fa2ae4366bfa..39cce60faf3d75cc137206584135de5935ad6982 100644
--- a/paddle/fluid/operators/reduce_ops/cub_reduce.h
+++ b/paddle/fluid/operators/reduce_ops/cub_reduce.h
@@ -26,6 +26,7 @@
 
 #ifdef __HIPCC__
 #include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
 #endif
 
 #include "paddle/fluid/framework/tensor.h"
@@ -71,12 +72,7 @@ template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
 __global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer,
                                TransformOp transformer, Ty init,
                                int reduce_num) {
-#ifdef __HIPCC__
-  __shared__
-      typename hipcub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
-#else
   __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
-#endif
   int idx_x = blockIdx.x * reduce_num;
   int idx_y = threadIdx.x;
   Ty reduce_var = init;
@@ -85,13 +81,8 @@ __global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer,
     reduce_var =
         reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
   __syncthreads();
-#ifdef __HIPCC__
-  reduce_var = hipcub::BlockReduce<Ty, BlockDim>(temp_storage)
-                   .Reduce(reduce_var, reducer);
-#else
   reduce_var =
       cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
-#endif
 
   if (threadIdx.x == 0) {
     y[blockIdx.x] = reduce_var;
@@ -107,12 +98,7 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
                              Array<int, ReduceRank> reduce_strides,
                              Array<int, Rank - ReduceRank> left_dim,
                              Array<int, Rank - ReduceRank> left_strides) {
-#ifdef __HIPCC__
-  __shared__
-      typename hipcub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
-#else
   __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
-#endif
   Array<int, Rank> sub_index;
   int left_idx = blockIdx.x;
   for (int i = 0; i < Rank - ReduceRank; ++i) {
@@ -144,13 +130,8 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
   }
   __syncthreads();
-#ifdef __HIPCC__
-  reduce_var = hipcub::BlockReduce<Ty, BlockDim>(temp_storage)
-                   .Reduce(reduce_var, reducer);
-#else
   reduce_var =
       cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
-#endif
 
   if (threadIdx.x == 0) {
     y[blockIdx.x] = reduce_var;
@@ -238,32 +219,17 @@ static void TensorReduceImpl(
   int rank = x_strides.size();
   int reduce_rank = reduce_strides.size();
   if (rank == reduce_rank) {
-#ifdef __HIPCC__
-    hipcub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
-        x_data, transformer);
-#else
     cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
         x_data, transformer);
-#endif
     size_t temp_storage_bytes = 0;
-#ifdef __HIPCC__
-    hipcub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
-                                 reduce_num, reducer, init, stream);
-#else
     cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
                               reduce_num, reducer, init, stream);
-#endif
     framework::Tensor tmp;
     auto* temp_storage = tmp.mutable_data<uint8_t>(
         framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
         place);
-#ifdef __HIPCC__
-    hipcub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x,
-                                 y_data, reduce_num, reducer, init, stream);
-#else
     cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x,
                               y_data, reduce_num, reducer, init, stream);
-#endif
     return;
   }
   if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
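Aside, not part of the patch: TensorReduceImpl's TransformInputIterator applies its functor on the fly while DeviceReduce streams the input, so a mean can be computed as sum(x[i]/n) without materializing a scaled copy. That is the trick reduce_mean below plays with DivideFunctor. A standalone sketch under the same alias; `MeanSketch` and `DivideBy` are illustrative names, and raw hipMalloc/cudaMalloc again stand in for Paddle's allocator:

    #ifdef __HIPCC__
    #include <hip/hip_runtime.h>
    #include <hipcub/hipcub.hpp>
    namespace cub = hipcub;
    #else
    #include <cuda_runtime.h>
    #include <cub/cub.cuh>
    #endif

    struct DivideBy {
      float n;
      __host__ __device__ float operator()(const float& x) const { return x / n; }
    };

    // d_in/d_out are device pointers; d_out receives the mean of d_in[0..n).
    inline void MeanSketch(const float* d_in, float* d_out, int n) {
      cub::TransformInputIterator<float, DivideBy, const float*> it(
          d_in, DivideBy{static_cast<float>(n)});
      size_t temp_bytes = 0;  // phase 1: query scratch size
      cub::DeviceReduce::Reduce(nullptr, temp_bytes, it, d_out, n, cub::Sum(),
                                0.0f);
      void* d_temp = nullptr;
    #ifdef __HIPCC__
      hipMalloc(&d_temp, temp_bytes);
    #else
      cudaMalloc(&d_temp, temp_bytes);
    #endif
      cub::DeviceReduce::Reduce(d_temp, temp_bytes, it, d_out, n, cub::Sum(),
                                0.0f);
    #ifdef __HIPCC__
      hipFree(d_temp);
    #else
      cudaFree(d_temp);
    #endif
    }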
diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cu b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cu
index d4d4e04f0cb09510bffc5924f460c77a4916b7d8..cc3653fcb43a4c000d0c61c9d854965fafd59a9c 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cu
@@ -56,15 +56,9 @@ class ReduceMeanKernel : public framework::OpKernel<T> {
     }
     auto stream = context.cuda_device_context().stream();
-#ifdef PADDLE_WITH_HIP
-    TensorReduce<T, T, hipcub::Sum, DivideFunctor<T>>(
-        *input, output, reduce_dims, static_cast<T>(0), hipcub::Sum(),
-        DivideFunctor<T>(reduce_num), stream);
-#else
     TensorReduce<T, T, cub::Sum, DivideFunctor<T>>(
         *input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
         DivideFunctor<T>(reduce_num), stream);
-#endif
   }
 };
 
diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu
index 495e4c180a0a9a544d13ee0aaaa80df701439b04..219cc231a1ea7a0786026d6dcc6d63ce78e24025 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu
@@ -56,25 +56,13 @@ class ReduceSumKernel : public framework::OpKernel<T> {
     if (out_dtype >= 0) {
       framework::VisitDataTypeSmall(
           static_cast<framework::proto::VarType::Type>(out_dtype),
-#ifdef __HIPCC__
-          TensorReduceFunctor<T, hipcub::Sum, IdentityFunctor<T>>(
-              *input, output, reduce_dims, static_cast<double>(0.0),
-              hipcub::Sum(), IdentityFunctor<T>(), stream));
-#else
           TensorReduceFunctor<T, cub::Sum, IdentityFunctor<T>>(
              *input, output, reduce_dims, static_cast<double>(0.0), cub::Sum(),
              IdentityFunctor<T>(), stream));
-#endif
     } else {
-#ifdef __HIPCC__
-      TensorReduce<T, T, hipcub::Sum, IdentityFunctor<T>>(
-          *input, output, reduce_dims, static_cast<T>(0), hipcub::Sum(),
-          IdentityFunctor<T>(), stream);
-#else
       TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
           *input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
          IdentityFunctor<T>(), stream);
-#endif
     }
   }
 };
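Aside, not part of the patch: the reduce kernels above and the sequence_softmax kernels below all funnel into cub::BlockReduce. A minimal sketch of that primitive; `BlockSum` is illustrative only. Every thread contributes one value and only thread 0 of the block receives the aggregate, which is why the kernels guard their stores with `threadIdx.x == 0`:

    #ifdef __HIPCC__
    #include <hipcub/hipcub.hpp>
    namespace cub = hipcub;
    #else
    #include <cub/cub.cuh>
    #endif

    template <int BlockDim>
    __global__ void BlockSum(const float* x, float* y) {
      typedef cub::BlockReduce<float, BlockDim> BlockReduce;
      __shared__ typename BlockReduce::TempStorage temp_storage;
      float v = x[blockIdx.x * BlockDim + threadIdx.x];
      v = BlockReduce(temp_storage).Reduce(v, cub::Sum());
      if (threadIdx.x == 0) y[blockIdx.x] = v;  // aggregate valid on thread 0 only
    }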
diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu
index 0c23533aaaa1f4dacb832526ad5b83392c3444f9..220165ac1bd4f6a80a2f3c0b21f5423352982588 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu
+++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu
@@ -20,6 +20,7 @@ limitations under the License. */
 
 #ifdef __HIPCC__
 #include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
 #endif
 
 #include "paddle/fluid/operators/math.h"
@@ -31,11 +32,7 @@ namespace operators {
 using LoDTensor = framework::LoDTensor;
 
 template <typename T, int BlockDim>
-#ifdef __HIPCC__
-using BlockReduce = hipcub::BlockReduce<T, BlockDim>;
-#else
 using BlockReduce = cub::BlockReduce<T, BlockDim>;
-#endif
 
 template <typename T, int BlockDim>
 using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;
@@ -57,13 +54,8 @@ __global__ void sequence_softmax_kernel(const T *in_data, const size_t *ref_lod,
       T ele = in_data[start + tid];
       max_ele = max_ele > ele ? max_ele : ele;
     }
-#ifdef __HIPCC__
-    max_ele =
-        BlockReduce<T, BlockDim>(temp_storage).Reduce(max_ele, hipcub::Max());
-#else
     max_ele =
         BlockReduce<T, BlockDim>(temp_storage).Reduce(max_ele, cub::Max());
-#endif
     if (threadIdx.x == 0) {
       shared_max_data = max_ele;
     }
@@ -75,13 +67,8 @@ __global__ void sequence_softmax_kernel(const T *in_data, const size_t *ref_lod,
       T ele = in_data[start + tid];
       sum_data += real_exp(ele - shared_max_data);
     }
-#ifdef __HIPCC__
-    sum_data =
-        BlockReduce<T, BlockDim>(temp_storage).Reduce(sum_data, hipcub::Sum());
-#else
     sum_data =
         BlockReduce<T, BlockDim>(temp_storage).Reduce(sum_data, cub::Sum());
-#endif
     if (threadIdx.x == 0) {
       shared_sum_data = sum_data;
     }
@@ -116,12 +103,7 @@ __global__ void sequence_softmax_grad_kernel(const T *softmax_grad_data,
       T s_d = softmax_data[idx];
       result += s_g_d * s_d;
     }
-#ifdef __HIPCC__
-    result =
-        BlockReduce<T, BlockDim>(temp_storage).Reduce(result, hipcub::Sum());
-#else
     result = BlockReduce<T, BlockDim>(temp_storage).Reduce(result, cub::Sum());
-#endif
     if (threadIdx.x == 0) {
       shared_data = result;
     }
diff --git a/paddle/fluid/operators/trace_op.cu b/paddle/fluid/operators/trace_op.cu
index a2d51e9c5bde7b470d2ee073f3c44f8a2313cc95..ea328361ded75ade9228fffe4dee0b4c6f0fc3e6 100644
--- a/paddle/fluid/operators/trace_op.cu
+++ b/paddle/fluid/operators/trace_op.cu
@@ -43,15 +43,9 @@ class TraceCUDAKernel : public framework::OpKernel<T> {
       auto stream = context.cuda_device_context().stream();
       std::vector<int> reduce_dims;
       reduce_dims.push_back(out->dims().size());
-#ifdef __HIPCC__
-      TensorReduce<T, T, hipcub::Sum, IdentityFunctor<T>>(
-          diag, out, reduce_dims, static_cast<T>(0), hipcub::Sum(),
-          IdentityFunctor<T>(), stream);
-#else
       TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
           diag, out, reduce_dims, static_cast<T>(0), cub::Sum(),
          IdentityFunctor<T>(), stream);
-#endif
     }
   }
 };
diff --git a/paddle/fluid/platform/gpu_launch_config.h b/paddle/fluid/platform/gpu_launch_config.h
index 422e5a987b6ad01343a8f446f097141f23e20b8c..e94bf6d89daa5ca6e7921af02b95edc70dd1765a 100644
--- a/paddle/fluid/platform/gpu_launch_config.h
+++ b/paddle/fluid/platform/gpu_launch_config.h
@@ -41,7 +41,11 @@ struct GpuLaunchConfig {
 
 inline GpuLaunchConfig GetGpuLaunchConfig1D(
     const platform::CUDADeviceContext& context, int element_count,
+#ifdef PADDLE_WITH_HIP
+    int max_threads = 256) {
+#else
     int max_threads = 1024) {
+#endif
   PADDLE_ENFORCE_GT(element_count, 0,
                     platform::errors::InvalidArgument(
                         "element count should be greater than 0,"
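Aside, not part of the patch: the 1024 to 256 drop above only changes a default, giving HIP callers a smaller per-block thread cap. A hedged sketch of the effect; `PickBlockSize` and its rounding policy are illustrative, not the Paddle helper:

    #ifdef PADDLE_WITH_HIP
    constexpr int kMaxThreadsPerBlock = 256;   // conservative cap used for ROCm
    #else
    constexpr int kMaxThreadsPerBlock = 1024;  // CUDA blocks allow up to 1024
    #endif

    // Pick a 1-D block size: small problems get a whole block to themselves,
    // large ones are capped, and the result is padded to a multiple of 32.
    inline int PickBlockSize(int element_count) {
      int threads = element_count < kMaxThreadsPerBlock ? element_count
                                                        : kMaxThreadsPerBlock;
      return ((threads + 31) / 32) * 32;
    }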
diff --git a/tools/dockerfile/Dockerfile.rocm b/tools/dockerfile/Dockerfile.rocm
index 6ae6b8963b7f5344856a505beed5c879b36925f8..eab4ef07c877897d0acaa4311182f34acc8ebc29 100644
--- a/tools/dockerfile/Dockerfile.rocm
+++ b/tools/dockerfile/Dockerfile.rocm
@@ -1,16 +1,16 @@
 # A image for building paddle binaries
 # Use rocm-terminal base image for both rocm environment
 # When you modify it, please be aware of rocm version
-#
-# Build: ROCM 3.9
+#
+# Build: ROCM 4.0.1
 # cd Paddle/tools/dockerfile
 # docker build -f Dockerfile.rocm \
-#        --build-arg ROCM_VERSION=3.9 \
-#        -t paddlepaddle/paddle-centos-rocm39-dev:latest .
-#
+#        --build-arg ROCM_VERSION=4.0.1 \
+#        -t paddlepaddle/paddle-centos-rocm401-dev:latest .
+#
 # docker run -it --device=/dev/kfd --device=/dev/dri \
 #        --security-opt seccomp=unconfined --group-add video \
-#        paddlepaddle/paddle-centos-rocm39-dev:latest /bin/bash
+#        paddlepaddle/paddle-centos-rocm401-dev:latest /bin/bash
 
 FROM centos:7.8.2003
 MAINTAINER PaddlePaddle Authors
@@ -21,7 +21,8 @@ ENV LANGUAGE en_US.UTF-8
 
 RUN yum install -y epel-release deltarpm sudo openssh-server gettext-devel sqlite-devel \
     zlib-devel openssl-devel pcre-devel vim tk-devel tkinter libtool xz graphviz wget curl-devel \
-    make bzip2 git patch unzip bison yasm diffutils automake which file kernel-headers kernel-devel
+    make bzip2 git patch unzip bison yasm diffutils automake which file kernel-headers kernel-devel \
+    net-tools numactl-devel chrpath
 
 # Install devtoolset-7
 RUN yum install -y yum-utils centos-release-scl && \
@@ -70,7 +71,7 @@ RUN cd /opt && wget -q https://paddle-ci.gz.bcebos.com/git-2.17.1.tar.gz && \
     make -j8 && make install && \
     cd .. && rm -rf git-2.17.1.tar.gz && rm -rf git-2.17.1
 
-ENV GOROOT=/usr/local/go 
+ENV GOROOT=/usr/local/go
 ENV GOPATH=/root/gopath
 ENV PATH=${GOROOT}/bin:${GOPATH}/bin:${PATH}
 
@@ -82,7 +83,7 @@ RUN wget --no-check-certificate -qO- https://storage.googleapis.com/golang/go1.8
     mkdir /root/gopath/src
 
 # protobuf 3.6.1
-RUN cd /opt && wget -q --no-check-certificate https://paddle-ci.cdn.bcebos.com/protobuf-cpp-3.6.1.tar.gz && \
+RUN cd /opt && wget -q --no-check-certificate https://paddle-ci.cdn.bcebos.com/protobuf-cpp-3.6.1.tar.gz && \
     tar xzf protobuf-cpp-3.6.1.tar.gz && \
     cd protobuf-3.6.1 && ./configure && make -j4 && make install && \
     cd .. && rm -f protobuf-cpp-3.6.1.tar.gz && rm -rf protobuf-3.6.1
@@ -91,28 +92,34 @@ RUN cd /opt && wget -q --no-check-certificate https://paddle-ci.cdn.bcebos.com/p
 RUN cd /opt && wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && chmod +x Miniconda3-latest-Linux-x86_64.sh
 RUN mkdir /opt/conda && ./Miniconda3-latest-Linux-x86_64.sh -b -f -p "/opt/conda" && rm -rf Miniconda3-latest-Linux-x86_64.sh
 ENV PATH=/opt/conda/bin:${PATH}
-RUN conda init bash && \
-    conda create -n python2.7 python=2.7 && \
-    conda create -n python3.7 python=3.7
+RUN conda init bash && conda install -n base jupyter
 
-# install paddle requirement
+# install Paddle requirement
 RUN wget https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/python/requirements.txt -O /root/requirements.txt
 RUN /opt/conda/bin/pip install -r /root/requirements.txt && \
-    /opt/conda/envs/python2.7/bin/pip install -r /root/requirements.txt && \
-    /opt/conda/envs/python3.7/bin/pip install -r /root/requirements.txt && \
     rm -rf /root/requirements.txt
 
 RUN wget https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/python/unittest_py/requirements.txt -O /root/requirements.txt
-RUN /opt/conda/bin/pip install -r /root/requirements.txt && \
-    /opt/conda/envs/python2.7/bin/pip install -r /root/requirements.txt && \
-    /opt/conda/envs/python3.7/bin/pip install -r /root/requirements.txt && \
-    rm -rf /root/requirements.txt
+RUN /opt/conda/bin/pip install -r /root/requirements.txt && rm -rf /root/requirements.txt
+
+# install PaddleClas requirement
+RUN wget https://raw.githubusercontent.com/PaddlePaddle/PaddleClas/develop/requirements.txt -O /root/requirements.txt
+RUN /opt/conda/bin/pip install -r /root/requirements.txt && rm -rf /root/requirements.txt
+
+# install PaddleDetection requirement
+RUN wget https://raw.githubusercontent.com/PaddlePaddle/PaddleDetection/develop/requirements.txt -O /root/requirements.txt
+RUN /opt/conda/bin/pip install -r /root/requirements.txt && rm -rf /root/requirements.txt
 
 # configure ssh
 RUN sed -i "s/^#PermitRootLogin/PermitRootLogin/" /etc/ssh/sshd_config && \
     sed -i "s/^#PubkeyAuthentication/PubkeyAuthentication/" /etc/ssh/sshd_config && \
     sed -i "s/^#RSAAuthentication/RSAAuthentication/" /etc/ssh/sshd_config
 
+# clang-format 3.8
+RUN wget https://copr.fedorainfracloud.org/coprs/alonid/llvm-3.8.0/repo/epel-7/alonid-llvm-3.8.0-epel-7.repo -P /etc/yum.repos.d/
+RUN yum install -y clang-3.8.0
+ENV PATH=/opt/llvm-3.8.0/bin:${PATH}
+
 # patchelf
 RUN yum install -y patchelf && \
     yum clean all && \