From 9b016c7cb7dedfc73b00dab5887c10cf70ee7636 Mon Sep 17 00:00:00 2001 From: Qi Li Date: Mon, 1 Mar 2021 11:58:02 +0800 Subject: [PATCH] [ROCM] update fluid operators for rocm (part2), test=develop (#31211) --- .../operators/distributed_ops/CMakeLists.txt | 2 +- .../operators/distributed_ops/allreduce_op.h | 8 ++- .../distributed_ops/broadcast_op.cu.cc | 8 ++- .../distributed_ops/ref_by_trainer_id_op.h | 2 +- paddle/fluid/operators/kron_op.h | 20 +++++-- paddle/fluid/operators/matmul_v2_op.h | 9 +++- paddle/fluid/operators/prelu_op.cu | 8 ++- .../fluid/operators/reduce_ops/CMakeLists.txt | 6 ++- .../fluid/operators/reduce_ops/cub_reduce.h | 52 +++++++++++++++++-- .../operators/reduce_ops/reduce_mean_op.cu | 6 +++ .../operators/reduce_ops/reduce_sum_op.cu | 12 +++++ .../operators/sequence_ops/sequence_mask_op.h | 4 +- .../sequence_ops/sequence_reverse_op.h | 4 +- .../sequence_softmax_cudnn_op.cu.cc | 9 ++++ .../sequence_ops/sequence_softmax_op.cc | 4 +- .../sequence_ops/sequence_softmax_op.cu | 29 ++++++++++- paddle/fluid/operators/trace_op.cu | 6 +++ 17 files changed, 163 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/operators/distributed_ops/CMakeLists.txt b/paddle/fluid/operators/distributed_ops/CMakeLists.txt index ec48a51baa2..e651f19fedb 100644 --- a/paddle/fluid/operators/distributed_ops/CMakeLists.txt +++ b/paddle/fluid/operators/distributed_ops/CMakeLists.txt @@ -30,7 +30,7 @@ endforeach() register_operators(EXCLUDES gen_nccl_id_op DEPS ${DISTRIBUTE_DEPS}) -if(WITH_NCCL) +if(WITH_NCCL OR WITH_RCCL) set(DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} nccl_common) endif() diff --git a/paddle/fluid/operators/distributed_ops/allreduce_op.h b/paddle/fluid/operators/distributed_ops/allreduce_op.h index e486faa5758..157924f0854 100644 --- a/paddle/fluid/operators/distributed_ops/allreduce_op.h +++ b/paddle/fluid/operators/distributed_ops/allreduce_op.h @@ -21,7 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/platform/nccl_helper.h" #endif @@ -36,7 +36,7 @@ class AllReduceOpKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(is_gpu_place(place), true, platform::errors::PreconditionNotMet( "AllReduce op can run on gpu place only for now.")); -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) auto& dev_ctx = ctx.template device_context(); auto in = ctx.Input("X"); auto out = ctx.Output("Out"); @@ -73,7 +73,11 @@ class AllReduceOpKernel : public framework::OpKernel { sendbuff, recvbuff, numel, static_cast(dtype), red_type, comm, stream)); if (ctx.Attr("sync_mode")) { +#ifdef PADDLE_WITH_RCCL + PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); +#endif } #else PADDLE_THROW(platform::errors::PreconditionNotMet( diff --git a/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc b/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc index 337422f0bd6..1bfcc8af03e 100644 --- a/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc +++ b/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc @@ -20,7 +20,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/platform/nccl_helper.h" #endif @@ -39,7 +39,7 @@ class NCCLBroadcastOpKernel : public framework::OpKernel { platform::errors::PreconditionNotMet( "The place of ExecutionContext should be CUDAPlace.")); -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) int dev_id = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).device; int root_dev_id = ctx.Attr("root"); @@ -68,7 +68,11 @@ class NCCLBroadcastOpKernel : public framework::OpKernel { << " From " << root_dev_id << " to " << dev_id; if (ctx.Attr("sync_mode")) { +#ifdef PADDLE_WITH_RCCL + PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); +#endif } #else PADDLE_THROW(platform::errors::PreconditionNotMet( diff --git a/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.h b/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.h index d8639627c3e..c8c437c4965 100644 --- a/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.h +++ b/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.h @@ -30,7 +30,7 @@ class RefByTrainerIdKernel : public framework::OpKernel { int64_t trainer_id = 0; auto* trainer_id_data = trainer_id_t->data(); if (platform::is_gpu_place(context.GetPlace())) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto stream = context.cuda_device_context().stream(); memory::Copy<>(platform::CPUPlace(), &trainer_id, BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), diff --git a/paddle/fluid/operators/kron_op.h b/paddle/fluid/operators/kron_op.h index 2af3716ae43..e74f537c852 100644 --- a/paddle/fluid/operators/kron_op.h +++ b/paddle/fluid/operators/kron_op.h @@ -18,7 +18,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/for_range.h" -#if __NVCC__ +#if defined(__NVCC__) || defined(__HIPCC__) #include "paddle/fluid/operators/reduce_ops/cub_reduce.h" #include "thrust/device_vector.h" #endif @@ -87,7 +87,7 @@ struct KronOpFunctor { const int64_t *p_stride_x = nullptr, *p_stride_y = nullptr, *p_stride_out = nullptr, *p_shape_y = nullptr; -#if __NVCC__ +#if defined(__NVCC__) || defined(__HIPCC__) thrust::device_vector d_stride_x(ndims); thrust::device_vector d_stride_y(ndims); thrust::device_vector d_stride_out(ndims); @@ -326,7 +326,7 @@ struct KronGradOpFunctor { const int64_t* p_stride_y = nullptr; const int64_t* p_stride_dout = nullptr; const int64_t* p_shape_y = nullptr; -#if __NVCC__ +#if defined(__NVCC__) || defined(__HIPCC__) thrust::device_vector d_stride_x(ndims); thrust::device_vector d_stride_y(ndims); thrust::device_vector d_stride_dout(ndims); @@ -369,7 +369,19 @@ struct KronGradOpFunctor { for_range(func); // reduce_sum along aixs 1 -#if __NVCC__ +#ifdef __HIPCC__ + auto stream = dev_ctx.stream(); // it is a cuda device_context + if (dx) { + TensorReduce>( + dout_x, dx, {1}, static_cast(0), hipcub::Sum(), + IdentityFunctor(), stream); + } + if (dy) { + TensorReduce>( + dout_y, dy, {1}, static_cast(0), hipcub::Sum(), + IdentityFunctor(), stream); + } +#elif defined(__NVCC__) auto stream = dev_ctx.stream(); // it is a cuda device_context if (dx) { TensorReduce>( diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index b6eac7bf0cc..f93a87831f1 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -25,7 +25,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/complex_functors.h" #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" -#ifdef __NVCC__ +#if defined(__NVCC__) || defined(__HIPCC__) #include "paddle/fluid/operators/reduce_ops/cub_reduce.h" #endif @@ -45,7 +45,12 @@ template void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output, const std::vector& reduce_dims, const paddle::framework::ExecutionContext& ctx) { -#ifdef __NVCC__ +#ifdef __HIPCC__ + auto stream = ctx.cuda_device_context().stream(); + TensorReduce>( + *input, output, reduce_dims, static_cast(0), hipcub::Sum(), + IdentityFunctor(), stream); +#elif defined(__NVCC__) auto stream = ctx.cuda_device_context().stream(); TensorReduce>( *input, output, reduce_dims, static_cast(0), cub::Sum(), diff --git a/paddle/fluid/operators/prelu_op.cu b/paddle/fluid/operators/prelu_op.cu index 2f61c53f877..52ce37878c2 100644 --- a/paddle/fluid/operators/prelu_op.cu +++ b/paddle/fluid/operators/prelu_op.cu @@ -95,7 +95,7 @@ __global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr, template class PreluOpGradFunctor { public: - void operator()(cudaStream_t stream, const T* x, const T* alpha, const T* dy, + void operator()(gpuStream_t stream, const T* x, const T* alpha, const T* dy, T* dx, T* dalpha, const framework::DDim& input_dims, PRELU_MODE mode) { size_t numel = 1; @@ -174,9 +174,15 @@ class CUDAPReluGradKernel : public framework::OpKernel { reduce_dims.push_back(i); } +#ifdef __HIPCC__ + TensorReduce>( + dalpha_tmp, dalpha, reduce_dims, static_cast(0), hipcub::Sum(), + IdentityFunctor(), stream); +#else TensorReduce>( dalpha_tmp, dalpha, reduce_dims, static_cast(0), cub::Sum(), IdentityFunctor(), stream); +#endif } }; diff --git a/paddle/fluid/operators/reduce_ops/CMakeLists.txt 
b/paddle/fluid/operators/reduce_ops/CMakeLists.txt index c32301e5e08..92107c9dc44 100644 --- a/paddle/fluid/operators/reduce_ops/CMakeLists.txt +++ b/paddle/fluid/operators/reduce_ops/CMakeLists.txt @@ -13,7 +13,7 @@ else() register_operators() endif() -if(WITH_GPU) +if(WITH_GPU OR WITH_ROCM) file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.part.cu") string(REPLACE ".part.cu" "" OPS "${OPS}") @@ -38,3 +38,7 @@ if(WITH_GPU) nv_test(check_reduce_rank_test SRCS check_reduce_rank_test.cu DEPS tensor) endif() endif() + +if(WITH_ROCM) + hip_test(check_reduce_rank_test SRCS check_reduce_rank_test.cu DEPS tensor) +endif() diff --git a/paddle/fluid/operators/reduce_ops/cub_reduce.h b/paddle/fluid/operators/reduce_ops/cub_reduce.h index 49bcbf3abb1..dad7c848a6c 100644 --- a/paddle/fluid/operators/reduce_ops/cub_reduce.h +++ b/paddle/fluid/operators/reduce_ops/cub_reduce.h @@ -20,7 +20,14 @@ #include #include -#include // NOLINT +#ifdef __NVCC__ +#include "cub/cub.cuh" // NOLINT +#endif + +#ifdef __HIPCC__ +#include +#endif + #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor_util.h" @@ -64,7 +71,12 @@ template ::TempStorage temp_storage; +#else __shared__ typename cub::BlockReduce::TempStorage temp_storage; +#endif int idx_x = blockIdx.x * reduce_num; int idx_y = threadIdx.x; Ty reduce_var = init; @@ -73,8 +85,13 @@ __global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer, reducer(reduce_var, static_cast(transformer(x[idx_x + idx_y]))); __syncthreads(); +#ifdef __HIPCC__ + reduce_var = hipcub::BlockReduce(temp_storage) + .Reduce(reduce_var, reducer); +#else reduce_var = cub::BlockReduce(temp_storage).Reduce(reduce_var, reducer); +#endif if (threadIdx.x == 0) { y[blockIdx.x] = reduce_var; @@ -90,7 +107,12 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer, Array reduce_strides, Array left_dim, Array left_strides) { +#ifdef __HIPCC__ + __shared__ + typename hipcub::BlockReduce::TempStorage temp_storage; +#else __shared__ typename cub::BlockReduce::TempStorage temp_storage; +#endif Array sub_index; int left_idx = blockIdx.x; for (int i = 0; i < Rank - ReduceRank; ++i) { @@ -122,8 +144,13 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer, } __syncthreads(); +#ifdef __HIPCC__ + reduce_var = hipcub::BlockReduce(temp_storage) + .Reduce(reduce_var, reducer); +#else reduce_var = cub::BlockReduce(temp_storage).Reduce(reduce_var, reducer); +#endif if (threadIdx.x == 0) { y[blockIdx.x] = reduce_var; @@ -188,7 +215,7 @@ static void TensorReduceImpl( int left_num, int reduce_num, const std::vector& x_strides, const std::vector& reduce_dim, const std::vector& reduce_strides, const std::vector& left_dim, const std::vector& left_strides, - cudaStream_t stream) { + gpuStream_t stream) { #define CUB_RANK_CASE(i, ...) 
\ case i: { \ constexpr auto kRank = i; \ @@ -211,17 +238,32 @@ static void TensorReduceImpl( int rank = x_strides.size(); int reduce_rank = reduce_strides.size(); if (rank == reduce_rank) { +#ifdef __HIPCC__ + hipcub::TransformInputIterator trans_x( + x_data, transformer); +#else cub::TransformInputIterator trans_x( x_data, transformer); +#endif size_t temp_storage_bytes = 0; +#ifdef __HIPCC__ + hipcub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data, + reduce_num, reducer, init, stream); +#else cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data, reduce_num, reducer, init, stream); +#endif framework::Tensor tmp; auto* temp_storage = tmp.mutable_data( framework::make_ddim({static_cast(temp_storage_bytes)}), place); +#ifdef __HIPCC__ + hipcub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, + y_data, reduce_num, reducer, init, stream); +#else cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data, reduce_num, reducer, init, stream); +#endif return; } if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) { @@ -280,7 +322,7 @@ template void TensorReduce(const framework::Tensor& x, framework::Tensor* y, std::vector origin_reduce_dims, const Ty& init, const ReduceOp& reducer, const TransformOp& transformer, - cudaStream_t stream) { + gpuStream_t stream) { auto x_dim = framework::vectorize(x.dims()); std::vector new_x_dim, new_reduce_dims; int is_reduced = 0; @@ -362,11 +404,11 @@ struct TensorReduceFunctor { const double& init; const ReduceOp& reducer; const TransformOp& transformer; - cudaStream_t stream; + gpuStream_t stream; TensorReduceFunctor(const framework::Tensor& x, framework::Tensor* y, std::vector origin_reduce_dims, const double& init, const ReduceOp& reducer, const TransformOp& transformer, - cudaStream_t stream) + gpuStream_t stream) : x(x), y(y), origin_reduce_dims(origin_reduce_dims), diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cu b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cu index cc3653fcb43..d4d4e04f0cb 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cu @@ -56,9 +56,15 @@ class ReduceMeanKernel : public framework::OpKernel { } auto stream = context.cuda_device_context().stream(); +#ifdef PADDLE_WITH_HIP + TensorReduce>( + *input, output, reduce_dims, static_cast(0), hipcub::Sum(), + DivideFunctor(reduce_num), stream); +#else TensorReduce>( *input, output, reduce_dims, static_cast(0), cub::Sum(), DivideFunctor(reduce_num), stream); +#endif } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu index 219cc231a1e..495e4c180a0 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu @@ -56,13 +56,25 @@ class ReduceSumKernel : public framework::OpKernel { if (out_dtype >= 0) { framework::VisitDataTypeSmall( static_cast(out_dtype), +#ifdef __HIPCC__ + TensorReduceFunctor>( + *input, output, reduce_dims, static_cast(0.0), + hipcub::Sum(), IdentityFunctor(), stream)); +#else TensorReduceFunctor>( *input, output, reduce_dims, static_cast(0.0), cub::Sum(), IdentityFunctor(), stream)); +#endif } else { +#ifdef __HIPCC__ + TensorReduce>( + *input, output, reduce_dims, static_cast(0), hipcub::Sum(), + IdentityFunctor(), stream); +#else TensorReduce>( *input, output, reduce_dims, static_cast(0), cub::Sum(), IdentityFunctor(), stream); +#endif } } }; diff --git 
a/paddle/fluid/operators/sequence_ops/sequence_mask_op.h b/paddle/fluid/operators/sequence_ops/sequence_mask_op.h index 3abaeccb283..2ce0b02d437 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_mask_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_mask_op.h @@ -14,7 +14,7 @@ #pragma once -#ifdef __NVCC__ +#if defined(__NVCC__) || defined(__HIPCC__) #include #include #include @@ -107,7 +107,7 @@ class SequenceMaskKernel : public framework::OpKernel { auto *x_data = x->data(); auto x_numel = x->numel(); if (maxlen < 0) { -#ifdef __NVCC__ +#if defined(__NVCC__) || defined(__HIPCC__) VLOG(10) << "SequenceMaskOp on GPU may be slow when maxlen is not provided."; maxlen = static_cast( diff --git a/paddle/fluid/operators/sequence_ops/sequence_reverse_op.h b/paddle/fluid/operators/sequence_ops/sequence_reverse_op.h index c84028bd63a..2094572a78a 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_reverse_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_reverse_op.h @@ -130,13 +130,13 @@ class SequenceReverseOpKernel : public framework::OpKernel { const size_t *lod; size_t lod_count = x.lod()[0].size(); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::is_gpu_place(ctx.GetPlace())) { lod = x.lod()[0].CUDAData(ctx.GetPlace()); } else { #endif lod = x.lod()[0].data(); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) } #endif diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_cudnn_op.cu.cc b/paddle/fluid/operators/sequence_ops/sequence_softmax_cudnn_op.cu.cc index b33d87e644f..46e4196585b 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_cudnn_op.cu.cc @@ -104,9 +104,18 @@ class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; + +#ifdef PADDLE_WITH_HIP +// MIOPEN not support float64 +REGISTER_OP_KERNEL(sequence_softmax, CUDNN, ::paddle::platform::CUDAPlace, + ops::SequenceSoftmaxCUDNNKernel); +REGISTER_OP_KERNEL(sequence_softmax_grad, CUDNN, ::paddle::platform::CUDAPlace, + ops::SequenceSoftmaxGradCUDNNKernel); +#else REGISTER_OP_KERNEL(sequence_softmax, CUDNN, ::paddle::platform::CUDAPlace, ops::SequenceSoftmaxCUDNNKernel, ops::SequenceSoftmaxCUDNNKernel); REGISTER_OP_KERNEL(sequence_softmax_grad, CUDNN, ::paddle::platform::CUDAPlace, ops::SequenceSoftmaxGradCUDNNKernel, ops::SequenceSoftmaxGradCUDNNKernel); +#endif diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc index 992a0b458b1..9a7bb67bdfc 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc @@ -36,7 +36,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel { // choose cudnn kernel if the runtime supported. bool use_cudnn = ctx.Attr("use_cudnn"); bool runtime_cudnn_support = false; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::is_gpu_place(ctx.GetPlace())) { auto& dev_ctx = ctx.template device_context(); @@ -132,7 +132,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel { // choose cudnn kernel if the runtime supported. 
bool use_cudnn = ctx.Attr("use_cudnn"); bool runtime_cudnn_support = false; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::is_gpu_place(ctx.GetPlace())) { auto& dev_ctx = ctx.template device_context(); diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu index 58022c076cf..0c23533aaaa 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cu @@ -13,7 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -#include // NOLINT + +#ifdef __NVCC__ +#include +#endif + +#ifdef __HIPCC__ +#include +#endif + #include "paddle/fluid/operators/math.h" #include "paddle/fluid/operators/sequence_ops/sequence_softmax_op.h" @@ -23,7 +31,11 @@ namespace operators { using LoDTensor = framework::LoDTensor; template +#ifdef __HIPCC__ +using BlockReduce = hipcub::BlockReduce; +#else using BlockReduce = cub::BlockReduce; +#endif template using BlockReduceTempStorage = typename BlockReduce::TempStorage; @@ -45,8 +57,13 @@ __global__ void sequence_softmax_kernel(const T *in_data, const size_t *ref_lod, T ele = in_data[start + tid]; max_ele = max_ele > ele ? max_ele : ele; } +#ifdef __HIPCC__ + max_ele = + BlockReduce(temp_storage).Reduce(max_ele, hipcub::Max()); +#else max_ele = BlockReduce(temp_storage).Reduce(max_ele, cub::Max()); +#endif if (threadIdx.x == 0) { shared_max_data = max_ele; } @@ -58,8 +75,13 @@ __global__ void sequence_softmax_kernel(const T *in_data, const size_t *ref_lod, T ele = in_data[start + tid]; sum_data += real_exp(ele - shared_max_data); } +#ifdef __HIPCC__ + sum_data = + BlockReduce(temp_storage).Reduce(sum_data, hipcub::Sum()); +#else sum_data = BlockReduce(temp_storage).Reduce(sum_data, cub::Sum()); +#endif if (threadIdx.x == 0) { shared_sum_data = sum_data; } @@ -94,7 +116,12 @@ __global__ void sequence_softmax_grad_kernel(const T *softmax_grad_data, T s_d = softmax_data[idx]; result += s_g_d * s_d; } +#ifdef __HIPCC__ + result = + BlockReduce(temp_storage).Reduce(result, hipcub::Sum()); +#else result = BlockReduce(temp_storage).Reduce(result, cub::Sum()); +#endif if (threadIdx.x == 0) { shared_data = result; } diff --git a/paddle/fluid/operators/trace_op.cu b/paddle/fluid/operators/trace_op.cu index ea328361ded..a2d51e9c5bd 100644 --- a/paddle/fluid/operators/trace_op.cu +++ b/paddle/fluid/operators/trace_op.cu @@ -43,9 +43,15 @@ class TraceCUDAKernel : public framework::OpKernel { auto stream = context.cuda_device_context().stream(); std::vector reduce_dims; reduce_dims.push_back(out->dims().size()); +#ifdef __HIPCC__ + TensorReduce>( + diag, out, reduce_dims, static_cast(0), hipcub::Sum(), + IdentityFunctor(), stream); +#else TensorReduce>( diag, out, reduce_dims, static_cast(0), cub::Sum(), IdentityFunctor(), stream); +#endif } } }; -- GitLab
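
Taken together, the hunks above apply one portability pattern: every CUDA-only construct gains a ROCm branch, with hipcub standing in for cub, hipStreamSynchronize for cudaStreamSynchronize, and gpuStream_t replacing cudaStream_t in function signatures so the same code path serves both runtimes. The sketch below is illustrative only and is not part of the patch; DeviceSum and the gpuprim/gpuMalloc/gpuStreamSynchronize aliases are hypothetical names introduced just to show the guard shape in isolation, assuming (as the patch does) that hipcub mirrors the cub API.

  // Illustrative sketch of the CUDA/ROCm guard pattern used throughout the
  // patch. Not Paddle code: DeviceSum and the gpu* aliases are stand-ins.
  #ifdef __HIPCC__
  #include <hipcub/hipcub.hpp>
  namespace gpuprim = hipcub;                  // hipcub mirrors the cub API
  using gpuStream_t = hipStream_t;
  #define gpuMalloc hipMalloc
  #define gpuFree hipFree
  #define gpuStreamSynchronize hipStreamSynchronize
  #else
  #include <cub/cub.cuh>
  namespace gpuprim = cub;
  using gpuStream_t = cudaStream_t;
  #define gpuMalloc cudaMalloc
  #define gpuFree cudaFree
  #define gpuStreamSynchronize cudaStreamSynchronize
  #endif

  // Sum n device elements into d_out on the given stream (error checks
  // omitted for brevity).
  template <typename T>
  void DeviceSum(const T* d_in, T* d_out, int n, gpuStream_t stream) {
    // cub and hipcub share the two-phase protocol: a nullptr pass queries
    // the scratch size, then the real reduction runs with that scratch.
    size_t temp_bytes = 0;
    gpuprim::DeviceReduce::Sum(nullptr, temp_bytes, d_in, d_out, n, stream);
    void* d_temp = nullptr;
    gpuMalloc(&d_temp, temp_bytes);
    gpuprim::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n, stream);
    // Mirror of the patch's sync_mode branch: hipStreamSynchronize on ROCm,
    // cudaStreamSynchronize on CUDA; synchronize before releasing scratch.
    gpuStreamSynchronize(stream);
    gpuFree(d_temp);
  }

The same shape covers the remaining hunks: where a single expression differs (hipcub::Sum() vs cub::Sum(), MIOPEN's missing float64 kernels), the patch keeps two guarded branches instead of an alias, but the selection macro (__HIPCC__ or PADDLE_WITH_HIP/PADDLE_WITH_RCCL) plays the same role.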