diff --git a/paddle/fluid/operators/split_lod_tensor_op.cc b/paddle/fluid/operators/split_lod_tensor_op.cc
index 4adbbacc844c64329c7c62f8969cdc3f42936beb..fe646b2830b66fb6f4ed8ffe14614d9bf7d9aa48 100644
--- a/paddle/fluid/operators/split_lod_tensor_op.cc
+++ b/paddle/fluid/operators/split_lod_tensor_op.cc
@@ -65,7 +65,7 @@ class SplitLoDTensorOp : public framework::OperatorBase {
     if (platform::is_cpu_place(mask.place())) {
       cpu_mask->ShareDataWith(mask);
     } else if (platform::is_gpu_place(mask.place())) {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       framework::TensorCopy(mask, platform::CPUPlace(), dev_ctx,
                             cpu_mask.get());
 #else
diff --git a/paddle/fluid/operators/sync_batch_norm_op.cu b/paddle/fluid/operators/sync_batch_norm_op.cu
index 26fbe39a3c3691b6ce6414d75f6da216bc888017..1c9e732b194ad70b43b2e00b1ae4b451a0662974 100644
--- a/paddle/fluid/operators/sync_batch_norm_op.cu
+++ b/paddle/fluid/operators/sync_batch_norm_op.cu
@@ -91,6 +91,16 @@ class SyncBatchNormGradKernel
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

+#ifdef PADDLE_WITH_HIP
+// MIOPEN does not support double
+REGISTER_OP_CUDA_KERNEL(
+    sync_batch_norm, ops::SyncBatchNormKernel,
+    ops::SyncBatchNormKernel);
+REGISTER_OP_CUDA_KERNEL(
+    sync_batch_norm_grad,
+    ops::SyncBatchNormGradKernel,
+    ops::SyncBatchNormGradKernel);
+#else
 REGISTER_OP_CUDA_KERNEL(
     sync_batch_norm, ops::SyncBatchNormKernel,
     ops::SyncBatchNormKernel,
@@ -100,5 +110,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::SyncBatchNormGradKernel,
     ops::SyncBatchNormGradKernel,
     ops::SyncBatchNormGradKernel);
+#endif

 // clang-format on
diff --git a/paddle/fluid/operators/sync_batch_norm_op.cu.h b/paddle/fluid/operators/sync_batch_norm_op.cu.h
index d52eaecb94c12defbcf1d850a16747236611e8f4..d08a34ade77f289810a181d611ef8d2801c1d56d 100644
--- a/paddle/fluid/operators/sync_batch_norm_op.cu.h
+++ b/paddle/fluid/operators/sync_batch_norm_op.cu.h
@@ -19,12 +19,19 @@ limitations under the License. */
 #include
 #include
 #include
+#ifdef __NVCC__
 #include "cub/cub.cuh"
+#include "paddle/fluid/platform/cudnn_helper.h"
+#endif
+#ifdef __HIPCC__
+#include
+namespace cub = hipcub;
+#include "paddle/fluid/platform/miopen_helper.h"
+#endif
 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/batch_norm_op.h"
 #include "paddle/fluid/operators/norm_utils.h"
-#include "paddle/fluid/platform/cudnn_helper.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/platform/nccl_helper.h"
@@ -186,7 +193,7 @@ void SyncBatchNormFunctor(const framework::ExecutionContext &ctx,
   auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
   memory::Copy(platform::CPUPlace(), c_g_st_d, gplace, stats, bytes, 0);

-#ifdef PADDLE_WITH_NCCL
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   auto *comm = dev_ctx.nccl_comm();
   if (comm) {
     int dtype = platform::ToNCCLDataType(mean_out->type());
@@ -460,7 +467,7 @@ void SyncBatchNormGradFunctor(
           dy_d, x_d, saved_mean, N, fsize, C, stats);
   }

-#ifdef PADDLE_WITH_NCCL
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   auto *comm = dev_ctx.nccl_comm();
   if (comm) {
     int dtype = platform::ToNCCLDataType(scale->type());
diff --git a/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h b/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h
index ce94ba1ce9e8c10e04cd3fe5fee214ef30ffb918..2d7fed2987f4b7425e599e4e76c51926fb5b3e1b 100644
--- a/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h
+++ b/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h
@@ -91,7 +91,7 @@ static bool TestLeakyReluGradGradMain(const framework::DDim &dim,
   int64_t limit = x.numel();

-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
   if (platform::is_gpu_place(place)) {
     auto &cuda_dev_ctx = dynamic_cast(dev_ctx);
     functor(cuda_dev_ctx, &x, out, &ddx, &ddout, dout, dx);
@@ -105,7 +105,7 @@ static bool TestLeakyReluGradGradMain(const framework::DDim &dim,
     platform::ForRange for_range(cpu_dev_ctx, limit);
     for_range(actual_functor);

-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
   }
 #endif
diff --git a/paddle/fluid/operators/top_k_function_cuda.h b/paddle/fluid/operators/top_k_function_cuda.h
index 41df6f488f1925273ac85c560b0639d162b3ac73..a7d7ea260ecdf44ab94e65f28db1294f7c57c527 100644
--- a/paddle/fluid/operators/top_k_function_cuda.h
+++ b/paddle/fluid/operators/top_k_function_cuda.h
@@ -16,11 +16,26 @@ limitations under the License. */
 #include
 #include
 #include
+#ifdef __NVCC__
 #include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include
+#endif
 #include "paddle/fluid/operators/top_k_op.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/float16.h"

+#ifdef __HIPCC__
+namespace rocprim {
+namespace detail {
+template <>
+struct radix_key_codec_base
+    : radix_key_codec_integral {};
+}  // namespace detail
+}  // namespace rocprim
+namespace cub = hipcub;
+#else
 // set cub base traits in order to handle float16
 namespace cub {
 template <>
@@ -28,6 +43,7 @@ struct NumericTraits
     : BaseTraits {};
 }  // namespace cub
+#endif

 namespace paddle {
 namespace operators {
@@ -439,6 +455,16 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
         input_indices.data(), sorted_indices_ptr, num_cols * num_rows,
         num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
         cu_stream);
+#ifdef __HIPCC__
+    if (err != hipSuccess) {
+      LOG(ERROR) << "TopKOP failed as could not launch "
+                    "hipcub::DeviceSegmentedRadixSort::SortPairsDescending to "
+                    "calculate "
+                    "temp_storage_bytes, status: "
+                 << hipGetErrorString(err);
+      return false;
+    }
+#else
     if (err != cudaSuccess) {
       LOG(ERROR) << "TopKOP failed as could not launch "
@@ -447,12 +473,22 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
                 << cudaGetErrorString(err);
       return false;
     }
+#endif
   } else {
     auto err = cub::DeviceSegmentedRadixSort::SortPairs(
         nullptr, temp_storage_bytes, input, sorted_values_ptr,
         input_indices.data(), sorted_indices_ptr, num_cols * num_rows,
         num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
         cu_stream);
+#ifdef __HIPCC__
+    if (err != hipSuccess) {
+      LOG(ERROR) << "TopKOP failed as could not launch "
+                    "hipcub::DeviceSegmentedRadixSort::SortPairs to calculate "
+                    "temp_storage_bytes, status: "
+                 << hipGetErrorString(err);
+      return false;
+    }
+#else
     if (err != cudaSuccess) {
       LOG(ERROR) << "TopKOP failed as could not launch "
                     "cub::DeviceSegmentedRadixSort::SortPairs to calculate "
@@ -460,6 +496,7 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
                 << cudaGetErrorString(err);
       return false;
     }
+#endif
   }
   Tensor temp_storage;
   temp_storage.mutable_data(ctx.GetPlace(), temp_storage_bytes);
@@ -470,6 +507,17 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
         sorted_values_ptr, input_indices.data(), sorted_indices_ptr,
         num_cols * num_rows, num_rows, segment_offsets_t,
         segment_offsets_t + 1, 0, sizeof(T) * 8, cu_stream);
+#ifdef __HIPCC__
+    if (err != hipSuccess) {
+      LOG(ERROR) << "TopKOP failed as could not launch "
+                    "hipcub::DeviceSegmentedRadixSort::SortPairsDescending to "
+                    "sort input, "
+                    "temp_storage_bytes: "
+                 << temp_storage_bytes
+                 << ", status: " << hipGetErrorString(err);
+      return false;
+    }
+#else
     if (err != cudaSuccess) {
       LOG(ERROR) << "TopKOP failed as could not launch "
                     "cub::DeviceSegmentedRadixSort::SortPairsDescending to "
@@ -479,12 +527,24 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
                 << ", status: " << cudaGetErrorString(err);
       return false;
     }
+#endif
   } else {
     auto err = cub::DeviceSegmentedRadixSort::SortPairs(
         temp_storage.data(), temp_storage_bytes, input, sorted_values_ptr,
         input_indices.data(), sorted_indices_ptr, num_cols * num_rows,
         num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
         cu_stream);
+#ifdef __HIPCC__
+    if (err != hipSuccess) {
+      LOG(ERROR) << "TopKOP failed as could not launch "
+                    "hipcub::DeviceSegmentedRadixSort::SortPairs to "
+                    "sort input, "
+                    "temp_storage_bytes: "
+                 << temp_storage_bytes
+                 << ", status: " << hipGetErrorString(err);
+      return false;
+    }
+#else
     if (err != cudaSuccess) {
       LOG(ERROR) << "TopKOP failed as could not launch "
                     "cub::DeviceSegmentedRadixSort::SortPairs to "
@@ -494,6 +554,7 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
                 << ", status: " << cudaGetErrorString(err);
       return false;
     }
+#endif
   }
   auto& dev = *ctx.eigen_device();
   if (k < num_cols) {
diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu
index 39a56f874d95029017f35e46792edd0935bb35cf..498f51d53adc724319a86d683fb6e913006856fb 100644
--- a/paddle/fluid/operators/top_k_op.cu
+++ b/paddle/fluid/operators/top_k_op.cu
@@ -15,7 +15,12 @@ limitations under the License. */
 #pragma once
 #include
 #include
+#ifdef __NVCC__
 #include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include
+#endif
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/top_k_function_cuda.h"
 #include "paddle/fluid/operators/top_k_op.h"
diff --git a/paddle/fluid/operators/trace_op.h b/paddle/fluid/operators/trace_op.h
index 54c4251a38cf10a8f489ca78346fae9471b464db..b7a6e559ed4ef6ee4cd43b9375b3531488db449d 100644
--- a/paddle/fluid/operators/trace_op.h
+++ b/paddle/fluid/operators/trace_op.h
@@ -145,7 +145,7 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context,
     int64_t pos = std::abs(offset) * offset_stride;
     int64_t dim_size = ret_strides.size();
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
     thrust::device_vector diag_vec(vectorize(dig_stride));
     const int64_t* diag_arr = thrust::raw_pointer_cast(diag_vec.data());
     thrust::device_vector ret_vec(ret_strides);
@@ -238,7 +238,7 @@ class TraceGradKernel : public framework::OpKernel {
     int64_t diag_size = len2 < len1 ? len2 : len1;
     int64_t pos = std::abs(offset) * offset_stride;
     if (diag_size > 0) {
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
       thrust::device_vector output_vec(vectorize(output_stride));
       const int64_t* output_arr = thrust::raw_pointer_cast(output_vec.data());
       thrust::device_vector input_vec(vectorize(input_stride));
diff --git a/paddle/fluid/operators/unique_op.cu b/paddle/fluid/operators/unique_op.cu
index d22406f27c4702ddd5e05daf802f3baeb26a02d6..87a46e11d9f91b4622a210b007142ec09bdbc9ce 100644
--- a/paddle/fluid/operators/unique_op.cu
+++ b/paddle/fluid/operators/unique_op.cu
@@ -16,6 +16,7 @@ limitations under the License. */
 #include
 #include
 #include
+#include
 #include
 #include
 #include
diff --git a/paddle/fluid/operators/unstack_op.h b/paddle/fluid/operators/unstack_op.h
index 6344ea16f81cddb1c8f4f07f28fd318f40296427..82118b692707fbb1459980b27573e770c978d521 100644
--- a/paddle/fluid/operators/unstack_op.h
+++ b/paddle/fluid/operators/unstack_op.h
@@ -18,7 +18,7 @@ limitations under the License. */
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/for_range.h" -#ifdef __NVCC__ +#if defined(__NVCC__) || defined(__HIPCC__) #include #include "paddle/fluid/framework/array.h" #endif @@ -103,7 +103,7 @@ class UnStackGradKernel : public framework::OpKernel { for (auto i = 0; i < axis; ++i) pre *= dim[i]; for (auto i = axis; i < dim.size(); ++i) post *= dim[i]; -#ifdef __NVCC__ +#if defined(__NVCC__) || defined(__HIPCC__) int total_num = pre * n * post; auto &dev_ctx = ctx.template device_context(); @@ -156,14 +156,14 @@ class UnStackKernel : public framework::OpKernel { int post = total_num / (n * pre); auto &dev_ctx = ctx.template device_context(); -#ifdef __NVCC__ +#if defined(__NVCC__) || defined(__HIPCC__) thrust::device_vector device_dx_vec(dx_datas); auto dx_data_arr = device_dx_vec.data().get(); #else auto dx_data_arr = dx_datas.data(); #endif StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, post); -#ifdef __NVCC__ +#if defined(__NVCC__) || defined(__HIPCC__) // Wait() must be called because device_dx_vec may be destructed before // kernel ends dev_ctx.Wait(); diff --git a/paddle/fluid/operators/warpctc_op.cc b/paddle/fluid/operators/warpctc_op.cc index f043b0179491979f2dc1ae35da5b99d2800d8764..f38f5d9f7235795d5460d6a436c6c7dc397e52da 100644 --- a/paddle/fluid/operators/warpctc_op.cc +++ b/paddle/fluid/operators/warpctc_op.cc @@ -16,6 +16,9 @@ limitations under the License. */ #include +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#endif #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cudnn_helper.h" #endif diff --git a/paddle/fluid/operators/warpctc_op.h b/paddle/fluid/operators/warpctc_op.h index 8b9276d4fa03f51e18b93e538aa944e8b719dc86..7451cac63d0cea3ea9e942427f7f641de602b5fa 100644 --- a/paddle/fluid/operators/warpctc_op.h +++ b/paddle/fluid/operators/warpctc_op.h @@ -159,6 +159,7 @@ class WarpCTCFunctor { warpctc_version_ = platform::dynload::get_warpctc_version(); if (platform::is_gpu_place(ctx.GetPlace())) { +// HIP not support ctcOptions in third-party warpctc #ifdef PADDLE_WITH_CUDA options_.loc = CTC_GPU; options_.stream = reinterpret_cast( diff --git a/paddle/fluid/platform/cuda_helper.h b/paddle/fluid/platform/cuda_helper.h index bfefeb2f4a3da5da4e7c2059decb1cf677f02a1e..30c38236c5244984c75eee9eb88fb452410a20ac 100644 --- a/paddle/fluid/platform/cuda_helper.h +++ b/paddle/fluid/platform/cuda_helper.h @@ -108,7 +108,11 @@ class CublasHandleHolder { } #endif +#ifdef PADDLE_WITH_HIP + const rocblas_handle& GetCublasHandle() const { return handle_; } +#else const cublasHandle_t& GetCublasHandle() const { return handle_; } +#endif ~CublasHandleHolder() PADDLE_MAY_THROW { #ifdef PADDLE_WITH_HIP diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 98dcf72aa4fb48709976aefa3c33cf518ba76fac..22daaf101cf200891ee98fd2c5c12b944d76825c 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -459,9 +459,15 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return context()->CudnnHandle(); } +#ifdef PADDLE_WITH_HIP +rocblas_handle CUDADeviceContext::cublas_handle() const { + return context()->CublasHandle()->GetCublasHandle(); +} +#else cublasHandle_t CUDADeviceContext::cublas_handle() const { return context()->CublasHandle()->GetCublasHandle(); } +#endif CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_); diff --git 
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index 11123c4e658ed9891336096b881a78d527dfd1c5..411fe09c864aa2899ada083d25374cd7c7c247b9 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -409,8 +409,12 @@ class CUDADeviceContext : public DeviceContext {
   cudnnHandle_t cudnn_handle() const;
 #endif

-  /*! \brief Return cublas handle in the device context. */
+/*! \brief Return cublas handle in the device context. */
+#ifdef PADDLE_WITH_HIP
+  rocblas_handle cublas_handle() const;
+#else
   cublasHandle_t cublas_handle() const;
+#endif

   /*! \brief Return a cudnn workspace handle to call multiple cudnn
    *  functions without interrupting by other threads.
diff --git a/paddle/fluid/platform/device_context_test.cu b/paddle/fluid/platform/device_context_test.cu
index 3e9fe461d746ca800a8ddd3f8aa12776b12479d6..2f9413c4f3ea7e72e8c3690d985329065570c223 100644
--- a/paddle/fluid/platform/device_context_test.cu
+++ b/paddle/fluid/platform/device_context_test.cu
@@ -47,7 +47,11 @@ TEST(Device, CUDADeviceContext) {
     cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
 #endif
     ASSERT_NE(nullptr, cudnn_handle);
+#ifdef PADDLE_WITH_HIP
+    rocblas_handle cublas_handle = device_context->cublas_handle();
+#else
     cublasHandle_t cublas_handle = device_context->cublas_handle();
+#endif
     ASSERT_NE(nullptr, cublas_handle);
     delete device_context;
   }
diff --git a/paddle/fluid/platform/miopen_desc.h b/paddle/fluid/platform/miopen_desc.h
index 68db32bac103bda29a89a7b74765f549b34b5519..7de713559ae410650f4722c89f39c9d803232e60 100644
--- a/paddle/fluid/platform/miopen_desc.h
+++ b/paddle/fluid/platform/miopen_desc.h
@@ -37,9 +37,9 @@ namespace platform {
 using framework::Tensor;

 template
-inline miopenDataType_t ToMIOpenDataType(const T& t) {
+inline miopenDataType_t ToCudnnDataType(const T& t) {
   auto type = framework::ToDataType(t);
-  return ToMIOpenDataType(type);
+  return ToCudnnDataType(type);
 }

 inline std::vector TransformDimOrder(const std::vector& dims) {
@@ -66,7 +66,7 @@ inline std::vector TransformDimOrder(const std::vector& dims) {
 }

 template <>
-inline miopenDataType_t ToMIOpenDataType(
+inline miopenDataType_t ToCudnnDataType(
     const framework::proto::VarType::Type& t) {
   miopenDataType_t type = miopenFloat;
   switch (t) {
@@ -84,37 +84,54 @@ inline miopenDataType_t ToMIOpenDataType(
 class ActivationDescriptor {
  public:
+  using T = miopenActivationDescriptor;
+  struct Deleter {
+    void operator()(T* t) {
+      if (t != nullptr) {
+        PADDLE_ENFORCE_CUDA_SUCCESS(
+            dynload::miopenDestroyActivationDescriptor(t));
+        t = nullptr;
+      }
+    }
+  };
   ActivationDescriptor() {
+    T* raw_ptr;
     PADDLE_ENFORCE_CUDA_SUCCESS(
-        dynload::miopenCreateActivationDescriptor(&desc_));
-  }
-  ~ActivationDescriptor() {
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        dynload::miopenDestroyActivationDescriptor(desc_));
+        dynload::miopenCreateActivationDescriptor(&raw_ptr));
+    desc_.reset(raw_ptr);
   }
   template
   void set(miopenActivationMode_t mode, const T& coef) {
     PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetActivationDescriptor(
-        desc_, mode, static_cast(coef), 0.0, 0.0));
+        desc_.get(), mode, static_cast(coef), 0.0, 0.0));
   }

-  miopenActivationDescriptor_t desc() { return desc_; }
-  miopenActivationDescriptor_t desc() const { return desc_; }
+  T* desc() { return desc_.get(); }
+  T* desc() const { return desc_.get(); }

  private:
-  miopenActivationDescriptor_t desc_;
+  std::unique_ptr desc_;
 };

 class TensorDescriptor {
  public:
+  using T = miopenTensorDescriptor;
+  struct Deleter {
+    void operator()(T* t) {
+      if (t != nullptr) {
+        PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyTensorDescriptor(t));
+        t = nullptr;
+      }
+    }
+  };
   TensorDescriptor() {
-    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenCreateTensorDescriptor(&desc_));
-  }
-  ~TensorDescriptor() {
-    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyTensorDescriptor(desc_));
+    T* raw_ptr;
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        dynload::miopenCreateTensorDescriptor(&raw_ptr));
+    desc_.reset(raw_ptr);
   }
-  miopenTensorDescriptor_t desc() { return desc_; }
-  miopenTensorDescriptor_t desc() const { return desc_; }
+  T* desc() { return desc_.get(); }
+  T* desc() const { return desc_.get(); }

   void set(const Tensor& tensor, const int groups = 1) {
     auto dims = framework::vectorize(tensor.dims());
@@ -128,7 +145,7 @@ class TensorDescriptor {
       dims_with_group[1] = dims_with_group[1] / groups;
     }
     PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetTensorDescriptor(
-        desc_, ToMIOpenDataType(tensor.type()),
+        (miopenTensorDescriptor_t)(desc_.get()), ToCudnnDataType(tensor.type()),
         static_cast(dims_with_group.size()),
         const_cast(dims_with_group.data()),
         const_cast(strides.data())));
@@ -136,6 +153,9 @@ class TensorDescriptor {

   void set(const Tensor& tensor, const miopenTensorFormat_t format) {
     const int groups = 1;
+    PADDLE_ENFORCE_EQ(format, MIOPEN_TENSOR_NCHW,
+                      platform::errors::InvalidArgument(
+                          "format should ONLY be NCHW in MIOPEN."));
     auto dims = framework::vectorize(tensor.dims());
     std::vector strides(dims.size());
     strides[dims.size() - 1] = 1;
@@ -147,26 +167,35 @@ class TensorDescriptor {
       dims_with_group[1] = dims_with_group[1] / groups;
     }
     PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetTensorDescriptor(
-        desc_, ToMIOpenDataType(tensor.type()),
+        (miopenTensorDescriptor_t)(desc_.get()), ToCudnnDataType(tensor.type()),
         static_cast(dims_with_group.size()),
         const_cast(dims_with_group.data()),
         const_cast(strides.data())));
   }

  private:
-  miopenTensorDescriptor_t desc_;
+  std::unique_ptr desc_;
 };

 class FilterDescriptor {
  public:
+  using T = miopenTensorDescriptor;
+  struct Deleter {
+    void operator()(T* t) {
+      if (t != nullptr) {
+        PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyTensorDescriptor(t));
+        t = nullptr;
+      }
+    }
+  };
   FilterDescriptor() {
-    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenCreateTensorDescriptor(&desc_));
-  }
-  ~FilterDescriptor() {
-    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyTensorDescriptor(desc_));
+    T* raw_ptr;
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        dynload::miopenCreateTensorDescriptor(&raw_ptr));
+    desc_.reset(raw_ptr);
   }
-  miopenTensorDescriptor_t desc() { return desc_; }
-  miopenTensorDescriptor_t desc() const { return desc_; }
+  T* desc() { return desc_.get(); }
+  T* desc() const { return desc_.get(); }

   void set(const Tensor& tensor, const miopenTensorFormat_t format,
            const int groups = 1) {
@@ -176,45 +205,55 @@ class FilterDescriptor {
                       platform::errors::InvalidArgument(
                           "format should ONLY be NCHW in MIOPEN."));
     transformed_dims = dims;
-    if (groups > 1) {
-      transformed_dims[1] = transformed_dims[1] / groups;
-    }
-    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetTensorDescriptor(
-        desc_, ToMIOpenDataType(tensor.type()),
-        static_cast(transformed_dims.size()),
-        const_cast(transformed_dims.data()), nullptr));
+    // if (groups > 1) {
+    //   transformed_dims[1] = transformed_dims[1] / groups;
+    // }
+    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSet4dTensorDescriptor(
+        (miopenTensorDescriptor_t)desc_.get(), ToCudnnDataType(tensor.type()),
+        transformed_dims[0], transformed_dims[1], transformed_dims[2],
+        transformed_dims[3]));
   }

  private:
-  miopenTensorDescriptor_t desc_;
+  std::unique_ptr desc_;
 };

 class ConvolutionDescriptor {
  public:
+  using T = miopenConvolutionDescriptor;
+  struct Deleter {
+    void operator()(T* t) {
+      if (t != nullptr) {
+        PADDLE_ENFORCE_CUDA_SUCCESS(
+            dynload::miopenDestroyConvolutionDescriptor(t));
+        t = nullptr;
+      }
+    }
+  };
   ConvolutionDescriptor() {
+    T* raw_ptr;
     PADDLE_ENFORCE_CUDA_SUCCESS(
-        dynload::miopenCreateConvolutionDescriptor(&desc_));
-  }
-  ~ConvolutionDescriptor() {
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        dynload::miopenDestroyConvolutionDescriptor(desc_));
+        dynload::miopenCreateConvolutionDescriptor(&raw_ptr));
+    desc_.reset(raw_ptr);
   }
-  miopenConvolutionDescriptor_t desc() { return desc_; }
-  miopenConvolutionDescriptor_t desc() const { return desc_; }
+  T* desc() { return desc_.get(); }
+  T* desc() const { return desc_.get(); }

   void set(miopenDataType_t dtype, const std::vector& pads,
            const std::vector& strides, const std::vector& dilations,
            bool allow_tf32, const int groups = 1) {
     PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenInitConvolutionNdDescriptor(
-        desc_, static_cast(pads.size()), const_cast(pads.data()),
+        (miopenConvolutionDescriptor_t)desc_.get(),
+        static_cast(pads.size()), const_cast(pads.data()),
         const_cast(strides.data()), const_cast(dilations.data()),
         miopenConvolution));
     PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::miopenSetConvolutionGroupCount(desc_, groups));
+        platform::dynload::miopenSetConvolutionGroupCount(
+            (miopenConvolutionDescriptor_t)desc_.get(), groups));
   }

  private:
-  miopenConvolutionDescriptor_t desc_;
+  std::unique_ptr desc_;
 };
 }  // namespace platform
diff --git a/paddle/fluid/platform/miopen_helper.h b/paddle/fluid/platform/miopen_helper.h
index f6045130851ee9d5dd07814956745495f52d2cab..435d28d518df1adf3ed37841324100e0bfbffa88 100644
--- a/paddle/fluid/platform/miopen_helper.h
+++ b/paddle/fluid/platform/miopen_helper.h
@@ -43,23 +43,6 @@ typedef enum {
   MIOPEN_TENSOR_NHWC = 1,
 } miopenTensorFormat_t;

-// MIOPEN do not support indirect function call defined in cudnnWorkspaceHandle
-struct miopenWorkspace {
-  explicit miopenWorkspace(size_t size) : size(size), data(NULL) {
-    PADDLE_ENFORCE_CUDA_SUCCESS(hipMalloc(&data, size));
-  }
-  miopenWorkspace(const miopenWorkspace&) = delete;
-  miopenWorkspace(miopenWorkspace&&) = default;
-  miopenWorkspace& operator=(miopenWorkspace&&) = default;
-  ~miopenWorkspace() {
-    if (data) {
-      hipFree(data);
-    }
-  }
-  size_t size;
-  void* data;
-};
-
 inline const char* miopenGetErrorString(miopenStatus_t status) {
   switch (status) {
     case miopenStatusSuccess:
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 21088e06a23afbf6e31bffcbd41372686d9650c3..58ef177863093dd75a3d700f7ba0079573365707 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -984,7 +984,7 @@ void BindImperative(py::module *m_ptr) {
               PADDLE_THROW(platform::errors::Unimplemented(
                   "Imperative allreduce is not supported when paddle is "
                   "not compiled with NCCL."));
-#endif  // PADDLE_WITH_NCCL
+#endif  // PADDLE_WITH_NCCL or PADDLE_WITH_RCCL
             }
           },
           py::call_guard())
@@ -1435,7 +1435,7 @@ void BindImperative(py::module *m_ptr) {
       py::call_guard());
 #endif

-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   py::class_>(
       m, "NCCLParallelContext")
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index d11f3c005eed50e13a28a7a26178a5077f141c55..2e5cd3473c3f6dfe3c3de38fde003e6bec6adab7 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -1125,7 +1125,7 @@ All parameter, weight, gradient are variables in Paddle.
       .def("get_fetch_list",
            [](Variable &self) { return self.GetMutable(); },
            py::return_value_policy::reference)
-#if (defined(PADDLE_WITH_NCCL))
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
       .def("get_communicator",
            [](Variable &self) -> platform::Communicator * {
              return self.GetMutable();