Unverified · Commit 4d647ec1 authored by Qi Li, committed by GitHub

[ROCM] update fluid platform for rocm (part5), test=develop (#31315)

Parent 522c91ec
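The change is mechanical but broad: CUDA-only preprocessor guards gain HIP counterparts, hipcub stands in for CUB, MIOpen helpers stand in for cuDNN helpers, and rocBLAS/RCCL types replace cuBLAS/NCCL ones under ROCm. A minimal sketch of the recurring guard pattern, distilled from the hunks below (illustrative only, not itself part of the commit):

```cpp
// Recurring portability pattern applied throughout this patch (sketch only).
// Device-side translation units pull in the matching vendor headers, and
// hipcub is aliased into the existing cub namespace so call sites compile
// unchanged under ROCm.
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

// Host-side feature guards are widened the same way:
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// code path shared by both GPU backends
#endif
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
// collective-communication path shared by NCCL and RCCL
#endif
```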
@@ -65,7 +65,7 @@ class SplitLoDTensorOp : public framework::OperatorBase {
     if (platform::is_cpu_place(mask.place())) {
       cpu_mask->ShareDataWith(mask);
     } else if (platform::is_gpu_place(mask.place())) {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       framework::TensorCopy(mask, platform::CPUPlace(), dev_ctx,
                             cpu_mask.get());
 #else
...
@@ -91,6 +91,16 @@ class SyncBatchNormGradKernel<platform::CUDADeviceContext, T>
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
+#ifdef PADDLE_WITH_HIP
+// MIOPEN do not support double
+REGISTER_OP_CUDA_KERNEL(
+    sync_batch_norm, ops::SyncBatchNormKernel<plat::CUDADeviceContext, float>,
+    ops::SyncBatchNormKernel<plat::CUDADeviceContext, plat::float16>);
+REGISTER_OP_CUDA_KERNEL(
+    sync_batch_norm_grad,
+    ops::SyncBatchNormGradKernel<plat::CUDADeviceContext, float>,
+    ops::SyncBatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);
+#else
 REGISTER_OP_CUDA_KERNEL(
     sync_batch_norm, ops::SyncBatchNormKernel<plat::CUDADeviceContext, float>,
     ops::SyncBatchNormKernel<plat::CUDADeviceContext, double>,
@@ -100,5 +110,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::SyncBatchNormGradKernel<plat::CUDADeviceContext, float>,
     ops::SyncBatchNormGradKernel<plat::CUDADeviceContext, double>,
     ops::SyncBatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);
+#endif
 // clang-format on
@@ -19,12 +19,19 @@ limitations under the License. */
 #include <cmath>
 #include <string>
 #include <vector>
+#ifdef __NVCC__
 #include "cub/cub.cuh"
+#include "paddle/fluid/platform/cudnn_helper.h"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#include "paddle/fluid/platform/miopen_helper.h"
+#endif
 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/batch_norm_op.h"
 #include "paddle/fluid/operators/norm_utils.h"
-#include "paddle/fluid/platform/cudnn_helper.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/platform/nccl_helper.h"
@@ -186,7 +193,7 @@ void SyncBatchNormFunctor(const framework::ExecutionContext &ctx,
   auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
   memory::Copy(platform::CPUPlace(), c_g_st_d, gplace, stats, bytes, 0);
-#ifdef PADDLE_WITH_NCCL
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   auto *comm = dev_ctx.nccl_comm();
   if (comm) {
     int dtype = platform::ToNCCLDataType(mean_out->type());
@@ -460,7 +467,7 @@ void SyncBatchNormGradFunctor(
         dy_d, x_d, saved_mean, N, fsize, C, stats);
   }
-#ifdef PADDLE_WITH_NCCL
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   auto *comm = dev_ctx.nccl_comm();
   if (comm) {
     int dtype = platform::ToNCCLDataType(scale->type());
...
@@ -91,7 +91,7 @@ static bool TestLeakyReluGradGradMain(const framework::DDim &dim,
   int64_t limit = x.numel();
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
   if (platform::is_gpu_place(place)) {
     auto &cuda_dev_ctx = dynamic_cast<platform::CUDADeviceContext &>(dev_ctx);
     functor(cuda_dev_ctx, &x, out, &ddx, &ddout, dout, dx);
@@ -105,7 +105,7 @@ static bool TestLeakyReluGradGradMain(const framework::DDim &dim,
     platform::ForRange<platform::CPUDeviceContext> for_range(cpu_dev_ctx,
                                                              limit);
     for_range(actual_functor);
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
   }
 #endif
...
@@ -16,11 +16,26 @@ limitations under the License. */
 #include <stdio.h>
 #include <cstdio>
 #include <vector>
+#ifdef __NVCC__
 #include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+#endif
 #include "paddle/fluid/operators/top_k_op.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/float16.h"
+#ifdef __HIPCC__
+namespace rocprim {
+namespace detail {
+template <>
+struct radix_key_codec_base<paddle::platform::float16>
+    : radix_key_codec_integral<paddle::platform::float16, uint16_t> {};
+}  // namespace detail
+}  // namespace rocprim
+namespace cub = hipcub;
+#else
 // set cub base traits in order to handle float16
 namespace cub {
 template <>
@@ -28,6 +43,7 @@ struct NumericTraits<paddle::platform::float16>
     : BaseTraits<FLOATING_POINT, true, false, uint16_t,
                  paddle::platform::float16> {};
 }  // namespace cub
+#endif
 namespace paddle {
 namespace operators {
@@ -439,6 +455,16 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
         input_indices.data<int64_t>(), sorted_indices_ptr, num_cols * num_rows,
         num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
         cu_stream);
+#ifdef __HIPCC__
+    if (err != hipSuccess) {
+      LOG(ERROR) << "TopKOP failed as could not launch "
+                    "hipcub::DeviceSegmentedRadixSort::SortPairsDescending to "
+                    "calculate "
+                    "temp_storage_bytes, status: "
+                 << hipGetErrorString(err);
+      return false;
+    }
+#else
     if (err != cudaSuccess) {
       LOG(ERROR)
          << "TopKOP failed as could not launch "
@@ -447,12 +473,22 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
          << cudaGetErrorString(err);
      return false;
    }
+#endif
   } else {
     auto err = cub::DeviceSegmentedRadixSort::SortPairs(
         nullptr, temp_storage_bytes, input, sorted_values_ptr,
         input_indices.data<int64_t>(), sorted_indices_ptr, num_cols * num_rows,
         num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
         cu_stream);
+#ifdef __HIPCC__
+    if (err != hipSuccess) {
+      LOG(ERROR) << "TopKOP failed as could not launch "
+                    "hipcub::DeviceSegmentedRadixSort::SortPairs to calculate "
+                    "temp_storage_bytes, status: "
+                 << hipGetErrorString(err);
+      return false;
+    }
+#else
     if (err != cudaSuccess) {
       LOG(ERROR) << "TopKOP failed as could not launch "
                     "cub::DeviceSegmentedRadixSort::SortPairs to calculate "
@@ -460,6 +496,7 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
                  << cudaGetErrorString(err);
       return false;
     }
+#endif
   }
   Tensor temp_storage;
   temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);
@@ -470,6 +507,17 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
         sorted_values_ptr, input_indices.data<int64_t>(), sorted_indices_ptr,
         num_cols * num_rows, num_rows, segment_offsets_t, segment_offsets_t + 1,
         0, sizeof(T) * 8, cu_stream);
+#ifdef __HIPCC__
+    if (err != hipSuccess) {
+      LOG(ERROR) << "TopKOP failed as could not launch "
+                    "hipcub::DeviceSegmentedRadixSort::SortPairsDescending to "
+                    "sort input, "
+                    "temp_storage_bytes: "
+                 << temp_storage_bytes
+                 << ", status: " << hipGetErrorString(err);
+      return false;
+    }
+#else
     if (err != cudaSuccess) {
       LOG(ERROR) << "TopKOP failed as could not launch "
                     "cub::DeviceSegmentedRadixSort::SortPairsDescending to "
@@ -479,12 +527,24 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
                 << ", status: " << cudaGetErrorString(err);
       return false;
     }
+#endif
   } else {
     auto err = cub::DeviceSegmentedRadixSort::SortPairs(
         temp_storage.data<uint8_t>(), temp_storage_bytes, input,
         sorted_values_ptr, input_indices.data<int64_t>(), sorted_indices_ptr,
         num_cols * num_rows, num_rows, segment_offsets_t, segment_offsets_t + 1,
         0, sizeof(T) * 8, cu_stream);
+#ifdef __HIPCC__
+    if (err != hipSuccess) {
+      LOG(ERROR) << "TopKOP failed as could not launch "
+                    "hipcub::DeviceSegmentedRadixSort::SortPairs to "
+                    "sort input, "
+                    "temp_storage_bytes: "
+                 << temp_storage_bytes
+                 << ", status: " << hipGetErrorString(err);
+      return false;
+    }
+#else
     if (err != cudaSuccess) {
       LOG(ERROR) << "TopKOP failed as could not launch "
                     "cub::DeviceSegmentedRadixSort::SortPairs to "
@@ -494,6 +554,7 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
                 << ", status: " << cudaGetErrorString(err);
       return false;
     }
+#endif
   }
   auto& dev = *ctx.eigen_device();
   if (k < num_cols) {
...
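The four error checks added above differ only in the runtime API they query (hipSuccess/hipGetErrorString versus cudaSuccess/cudaGetErrorString). As a hedged aside, not part of the commit and with hypothetical gpu* names, the same checks could be written once behind a thin alias layer:

```cpp
// Hypothetical alias layer (illustrative only): unify HIP/CUDA error
// handling so per-backend LOG(ERROR) blocks need not be duplicated.
#ifdef __HIPCC__
#include <hip/hip_runtime.h>
using gpuError_t = hipError_t;
constexpr gpuError_t gpuSuccess = hipSuccess;
inline const char* gpuGetErrorString(gpuError_t e) { return hipGetErrorString(e); }
#else
#include <cuda_runtime.h>
using gpuError_t = cudaError_t;
constexpr gpuError_t gpuSuccess = cudaSuccess;
inline const char* gpuGetErrorString(gpuError_t e) { return cudaGetErrorString(e); }
#endif
```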
@@ -15,7 +15,12 @@ limitations under the License. */
 #pragma once
 #include <cstdio>
 #include <vector>
+#ifdef __NVCC__
 #include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+#endif
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/top_k_function_cuda.h"
 #include "paddle/fluid/operators/top_k_op.h"
...
@@ -145,7 +145,7 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context,
     int64_t pos = std::abs(offset) * offset_stride;
     int64_t dim_size = ret_strides.size();
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
     thrust::device_vector<int64_t> diag_vec(vectorize(dig_stride));
     const int64_t* diag_arr = thrust::raw_pointer_cast(diag_vec.data());
     thrust::device_vector<int64_t> ret_vec(ret_strides);
@@ -238,7 +238,7 @@ class TraceGradKernel : public framework::OpKernel<T> {
     int64_t diag_size = len2 < len1 ? len2 : len1;
     int64_t pos = std::abs(offset) * offset_stride;
     if (diag_size > 0) {
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
       thrust::device_vector<int64_t> output_vec(vectorize(output_stride));
       const int64_t* output_arr = thrust::raw_pointer_cast(output_vec.data());
       thrust::device_vector<int64_t> input_vec(vectorize(input_stride));
...
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <thrust/execution_policy.h>
 #include <thrust/functional.h>
 #include <thrust/scatter.h>
+#include <thrust/sequence.h>
 #include <thrust/unique.h>
 #include <iostream>
 #include <vector>
...
@@ -18,7 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/for_range.h"
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
 #include <thrust/device_vector.h>
 #include "paddle/fluid/framework/array.h"
 #endif
@@ -103,7 +103,7 @@ class UnStackGradKernel : public framework::OpKernel<T> {
     for (auto i = 0; i < axis; ++i) pre *= dim[i];
     for (auto i = axis; i < dim.size(); ++i) post *= dim[i];
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
     int total_num = pre * n * post;
     auto &dev_ctx = ctx.template device_context<DeviceContext>();
@@ -156,14 +156,14 @@ class UnStackKernel : public framework::OpKernel<T> {
     int post = total_num / (n * pre);
     auto &dev_ctx = ctx.template device_context<DeviceContext>();
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
     thrust::device_vector<T *> device_dx_vec(dx_datas);
     auto dx_data_arr = device_dx_vec.data().get();
 #else
     auto dx_data_arr = dx_datas.data();
 #endif
     StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, post);
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
     // Wait() must be called because device_dx_vec may be destructed before
     // kernel ends
     dev_ctx.Wait();
...
@@ -16,6 +16,9 @@ limitations under the License. */
 #include <memory>
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/miopen_helper.h"
+#endif
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cudnn_helper.h"
 #endif
...
@@ -159,6 +159,7 @@ class WarpCTCFunctor {
     warpctc_version_ = platform::dynload::get_warpctc_version();
     if (platform::is_gpu_place(ctx.GetPlace())) {
+// HIP not support ctcOptions in third-party warpctc
 #ifdef PADDLE_WITH_CUDA
       options_.loc = CTC_GPU;
       options_.stream = reinterpret_cast<const platform::CUDADeviceContext&>(
...
@@ -108,7 +108,11 @@ class CublasHandleHolder {
   }
 #endif
+#ifdef PADDLE_WITH_HIP
+  const rocblas_handle& GetCublasHandle() const { return handle_; }
+#else
   const cublasHandle_t& GetCublasHandle() const { return handle_; }
+#endif
   ~CublasHandleHolder() PADDLE_MAY_THROW {
 #ifdef PADDLE_WITH_HIP
...
@@ -459,9 +459,15 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
   return context()->CudnnHandle();
 }
+#ifdef PADDLE_WITH_HIP
+rocblas_handle CUDADeviceContext::cublas_handle() const {
+  return context()->CublasHandle()->GetCublasHandle();
+}
+#else
 cublasHandle_t CUDADeviceContext::cublas_handle() const {
   return context()->CublasHandle()->GetCublasHandle();
 }
+#endif
 CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
   return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
...
@@ -409,8 +409,12 @@ class CUDADeviceContext : public DeviceContext {
   cudnnHandle_t cudnn_handle() const;
 #endif
   /*! \brief Return cublas handle in the device context. */
+#ifdef PADDLE_WITH_HIP
+  rocblas_handle cublas_handle() const;
+#else
   cublasHandle_t cublas_handle() const;
+#endif
   /*! \brief Return a cudnn workspace handle to call multiple cudnn
    * functions without interrupting by other threads.
...
@@ -47,7 +47,11 @@ TEST(Device, CUDADeviceContext) {
     cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
 #endif
     ASSERT_NE(nullptr, cudnn_handle);
+#ifdef PADDLE_WITH_HIP
+    rocblas_handle cublas_handle = device_context->cublas_handle();
+#else
     cublasHandle_t cublas_handle = device_context->cublas_handle();
+#endif
     ASSERT_NE(nullptr, cublas_handle);
     delete device_context;
   }
...
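Because the accessor keeps the name cublas_handle() on both backends, only code that has to spell out the handle type (like the test above) needs an #ifdef; callers that use auto stay backend-neutral. A self-contained analogue with hypothetical placeholder types (the real code returns rocblas_handle or cublasHandle_t):

```cpp
// Minimal analogue of the pattern above (placeholder types, illustrative only):
// one accessor name, backend-specific return type selected at compile time.
#include <iostream>

#ifdef PADDLE_WITH_HIP
struct RocblasHandlePlaceholder {};            // stands in for rocblas_handle
using blas_handle_t = RocblasHandlePlaceholder*;
#else
struct CublasHandlePlaceholder {};             // stands in for cublasHandle_t
using blas_handle_t = CublasHandlePlaceholder*;
#endif

class DeviceContextLike {
 public:
  blas_handle_t cublas_handle() const { return handle_; }

 private:
  blas_handle_t handle_ = nullptr;
};

int main() {
  DeviceContextLike ctx;
  auto handle = ctx.cublas_handle();  // type differs per build, call site does not
  std::cout << (handle == nullptr) << std::endl;
  return 0;
}
```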
@@ -37,9 +37,9 @@ namespace platform {
 using framework::Tensor;

 template <typename T>
-inline miopenDataType_t ToMIOpenDataType(const T& t) {
+inline miopenDataType_t ToCudnnDataType(const T& t) {
   auto type = framework::ToDataType(t);
-  return ToMIOpenDataType(type);
+  return ToCudnnDataType(type);
 }

 inline std::vector<int> TransformDimOrder(const std::vector<int>& dims) {
@@ -66,7 +66,7 @@ inline std::vector<int> TransformDimOrder(const std::vector<int>& dims) {
 }

 template <>
-inline miopenDataType_t ToMIOpenDataType(
+inline miopenDataType_t ToCudnnDataType(
     const framework::proto::VarType::Type& t) {
   miopenDataType_t type = miopenFloat;
   switch (t) {
@@ -84,37 +84,54 @@ inline miopenDataType_t ToMIOpenDataType(
 class ActivationDescriptor {
  public:
+  using T = miopenActivationDescriptor;
+  struct Deleter {
+    void operator()(T* t) {
+      if (t != nullptr) {
+        PADDLE_ENFORCE_CUDA_SUCCESS(
+            dynload::miopenDestroyActivationDescriptor(t));
+        t = nullptr;
+      }
+    }
+  };
   ActivationDescriptor() {
+    T* raw_ptr;
     PADDLE_ENFORCE_CUDA_SUCCESS(
-        dynload::miopenCreateActivationDescriptor(&desc_));
-  }
-  ~ActivationDescriptor() {
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        dynload::miopenDestroyActivationDescriptor(desc_));
+        dynload::miopenCreateActivationDescriptor(&raw_ptr));
+    desc_.reset(raw_ptr);
   }
   template <typename T>
   void set(miopenActivationMode_t mode, const T& coef) {
     PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetActivationDescriptor(
-        desc_, mode, static_cast<double>(coef), 0.0, 0.0));
+        desc_.get(), mode, static_cast<double>(coef), 0.0, 0.0));
   }
-  miopenActivationDescriptor_t desc() { return desc_; }
-  miopenActivationDescriptor_t desc() const { return desc_; }
+  T* desc() { return desc_.get(); }
+  T* desc() const { return desc_.get(); }

  private:
-  miopenActivationDescriptor_t desc_;
+  std::unique_ptr<T, Deleter> desc_;
 };

 class TensorDescriptor {
  public:
+  using T = miopenTensorDescriptor;
+  struct Deleter {
+    void operator()(T* t) {
+      if (t != nullptr) {
+        PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyTensorDescriptor(t));
+        t = nullptr;
+      }
+    }
+  };
   TensorDescriptor() {
-    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenCreateTensorDescriptor(&desc_));
-  }
-  ~TensorDescriptor() {
-    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyTensorDescriptor(desc_));
+    T* raw_ptr;
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        dynload::miopenCreateTensorDescriptor(&raw_ptr));
+    desc_.reset(raw_ptr);
   }
-  miopenTensorDescriptor_t desc() { return desc_; }
-  miopenTensorDescriptor_t desc() const { return desc_; }
+  T* desc() { return desc_.get(); }
+  T* desc() const { return desc_.get(); }

   void set(const Tensor& tensor, const int groups = 1) {
     auto dims = framework::vectorize<int>(tensor.dims());
@@ -128,7 +145,7 @@ class TensorDescriptor {
       dims_with_group[1] = dims_with_group[1] / groups;
     }
     PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetTensorDescriptor(
-        desc_, ToMIOpenDataType(tensor.type()),
+        (miopenTensorDescriptor_t)(desc_.get()), ToCudnnDataType(tensor.type()),
         static_cast<int>(dims_with_group.size()),
         const_cast<int*>(dims_with_group.data()),
         const_cast<int*>(strides.data())));
@@ -136,6 +153,9 @@ class TensorDescriptor {
   void set(const Tensor& tensor, const miopenTensorFormat_t format) {
     const int groups = 1;
+    PADDLE_ENFORCE_EQ(format, MIOPEN_TENSOR_NCHW,
+                      platform::errors::InvalidArgument(
+                          "format should ONLY be NCHW in MIOPEN."));
     auto dims = framework::vectorize<int>(tensor.dims());
     std::vector<int> strides(dims.size());
     strides[dims.size() - 1] = 1;
@@ -147,26 +167,35 @@ class TensorDescriptor {
       dims_with_group[1] = dims_with_group[1] / groups;
     }
     PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetTensorDescriptor(
-        desc_, ToMIOpenDataType(tensor.type()),
+        (miopenTensorDescriptor_t)(desc_.get()), ToCudnnDataType(tensor.type()),
         static_cast<int>(dims_with_group.size()),
         const_cast<int*>(dims_with_group.data()),
         const_cast<int*>(strides.data())));
   }

  private:
-  miopenTensorDescriptor_t desc_;
+  std::unique_ptr<T, Deleter> desc_;
 };

 class FilterDescriptor {
  public:
+  using T = miopenTensorDescriptor;
+  struct Deleter {
+    void operator()(T* t) {
+      if (t != nullptr) {
+        PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyTensorDescriptor(t));
+        t = nullptr;
+      }
+    }
+  };
   FilterDescriptor() {
-    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenCreateTensorDescriptor(&desc_));
-  }
-  ~FilterDescriptor() {
-    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyTensorDescriptor(desc_));
+    T* raw_ptr;
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        dynload::miopenCreateTensorDescriptor(&raw_ptr));
+    desc_.reset(raw_ptr);
   }
-  miopenTensorDescriptor_t desc() { return desc_; }
-  miopenTensorDescriptor_t desc() const { return desc_; }
+  T* desc() { return desc_.get(); }
+  T* desc() const { return desc_.get(); }

   void set(const Tensor& tensor, const miopenTensorFormat_t format,
            const int groups = 1) {
@@ -176,45 +205,55 @@ class FilterDescriptor {
                       platform::errors::InvalidArgument(
                           "format should ONLY be NCHW in MIOPEN."));
     transformed_dims = dims;
-    if (groups > 1) {
-      transformed_dims[1] = transformed_dims[1] / groups;
-    }
-    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetTensorDescriptor(
-        desc_, ToMIOpenDataType(tensor.type()),
-        static_cast<int>(transformed_dims.size()),
-        const_cast<int*>(transformed_dims.data()), nullptr));
+    // if (groups > 1) {
+    //   transformed_dims[1] = transformed_dims[1] / groups;
+    // }
+    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSet4dTensorDescriptor(
+        (miopenTensorDescriptor_t)desc_.get(), ToCudnnDataType(tensor.type()),
+        transformed_dims[0], transformed_dims[1], transformed_dims[2],
+        transformed_dims[3]));
   }

  private:
-  miopenTensorDescriptor_t desc_;
+  std::unique_ptr<T, Deleter> desc_;
 };

 class ConvolutionDescriptor {
  public:
+  using T = miopenConvolutionDescriptor;
+  struct Deleter {
+    void operator()(T* t) {
+      if (t != nullptr) {
+        PADDLE_ENFORCE_CUDA_SUCCESS(
+            dynload::miopenDestroyConvolutionDescriptor(t));
+        t = nullptr;
+      }
+    }
+  };
   ConvolutionDescriptor() {
+    T* raw_ptr;
     PADDLE_ENFORCE_CUDA_SUCCESS(
-        dynload::miopenCreateConvolutionDescriptor(&desc_));
-  }
-  ~ConvolutionDescriptor() {
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        dynload::miopenDestroyConvolutionDescriptor(desc_));
+        dynload::miopenCreateConvolutionDescriptor(&raw_ptr));
+    desc_.reset(raw_ptr);
   }
-  miopenConvolutionDescriptor_t desc() { return desc_; }
-  miopenConvolutionDescriptor_t desc() const { return desc_; }
+  T* desc() { return desc_.get(); }
+  T* desc() const { return desc_.get(); }

   void set(miopenDataType_t dtype, const std::vector<int>& pads,
            const std::vector<int>& strides, const std::vector<int>& dilations,
            bool allow_tf32, const int groups = 1) {
     PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenInitConvolutionNdDescriptor(
-        desc_, static_cast<int>(pads.size()), const_cast<int*>(pads.data()),
+        (miopenConvolutionDescriptor_t)desc_.get(),
+        static_cast<int>(pads.size()), const_cast<int*>(pads.data()),
         const_cast<int*>(strides.data()), const_cast<int*>(dilations.data()),
         miopenConvolution));
     PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::miopenSetConvolutionGroupCount(desc_, groups));
+        platform::dynload::miopenSetConvolutionGroupCount(
+            (miopenConvolutionDescriptor_t)desc_.get(), groups));
   }

  private:
-  miopenConvolutionDescriptor_t desc_;
+  std::unique_ptr<T, Deleter> desc_;
 };

 }  // namespace platform
...
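Every descriptor class above makes the same ownership change: the raw MIOpen handle moves into a std::unique_ptr with a custom Deleter, which removes the hand-written destructors and keeps destruction in one place. A generic, hedged sketch of that pattern (names and signatures here are illustrative, not Paddle or MIOpen APIs; the status type is simplified to int):

```cpp
// Generic RAII sketch of the unique_ptr + custom Deleter pattern used above.
#include <memory>

template <typename Handle>
class ScopedHandle {
 public:
  using CreateFn = int (*)(Handle**);   // mimics a miopenCreate* signature
  using DestroyFn = int (*)(Handle*);   // mimics a miopenDestroy* signature

  ScopedHandle(CreateFn create, DestroyFn destroy)
      : desc_(nullptr, Deleter{destroy}) {
    Handle* raw_ptr = nullptr;
    create(&raw_ptr);      // e.g. a miopenCreateTensorDescriptor-style call
    desc_.reset(raw_ptr);  // ownership transferred; Deleter releases it later
  }

  Handle* get() const { return desc_.get(); }

 private:
  struct Deleter {
    DestroyFn destroy;
    void operator()(Handle* h) const {
      if (h != nullptr) destroy(h);  // e.g. a miopenDestroyTensorDescriptor-style call
    }
  };
  std::unique_ptr<Handle, Deleter> desc_;
};
```

Each class in the diff inlines this pattern with its specific create/destroy pair rather than sharing a template; the visible API change for callers is that desc() now returns the raw pointer obtained through desc_.get().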
@@ -43,23 +43,6 @@ typedef enum {
   MIOPEN_TENSOR_NHWC = 1,
 } miopenTensorFormat_t;

-// MIOPEN do not support indirect function call defined in cudnnWorkspaceHandle
-struct miopenWorkspace {
-  explicit miopenWorkspace(size_t size) : size(size), data(NULL) {
-    PADDLE_ENFORCE_CUDA_SUCCESS(hipMalloc(&data, size));
-  }
-  miopenWorkspace(const miopenWorkspace&) = delete;
-  miopenWorkspace(miopenWorkspace&&) = default;
-  miopenWorkspace& operator=(miopenWorkspace&&) = default;
-  ~miopenWorkspace() {
-    if (data) {
-      hipFree(data);
-    }
-  }
-  size_t size;
-  void* data;
-};

 inline const char* miopenGetErrorString(miopenStatus_t status) {
   switch (status) {
     case miopenStatusSuccess:
...
@@ -984,7 +984,7 @@ void BindImperative(py::module *m_ptr) {
              PADDLE_THROW(platform::errors::Unimplemented(
                  "Imperative allreduce is not supported when paddle is "
                  "not compiled with NCCL."));
-#endif  // PADDLE_WITH_NCCL
+#endif  // PADDLE_WITH_NCCL or PADDLE_WITH_RCCL
            }
          },
          py::call_guard<py::gil_scoped_release>())
@@ -1435,7 +1435,7 @@ void BindImperative(py::module *m_ptr) {
       py::call_guard<py::gil_scoped_release>());
 #endif
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   py::class_<imperative::NCCLParallelContext, imperative::ParallelContext,
              std::shared_ptr<imperative::NCCLParallelContext>>(
       m, "NCCLParallelContext")
...
@@ -1125,7 +1125,7 @@ All parameter, weight, gradient are variables in Paddle.
       .def("get_fetch_list",
            [](Variable &self) { return self.GetMutable<FetchList>(); },
            py::return_value_policy::reference)
-#if (defined(PADDLE_WITH_NCCL))
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
       .def("get_communicator",
            [](Variable &self) -> platform::Communicator * {
              return self.GetMutable<platform::Communicator>();
...