diff --git a/paddle/fluid/extension/include/ext_place.h b/paddle/fluid/extension/include/ext_place.h index c9ed40a382417f6a85c4213f1d5d965797863c92..91d4f41c2135145062fb98b2df77b14ec4e1a32b 100644 --- a/paddle/fluid/extension/include/ext_place.h +++ b/paddle/fluid/extension/include/ext_place.h @@ -17,6 +17,6 @@ limitations under the License. */ namespace paddle { // TODO(yangjiabin): Add other place support in next PR -enum class PlaceType { kUNK = -1, kCPU, kGPU, kHIP }; +enum class PlaceType { kUNK = -1, kCPU, kGPU }; } // namespace paddle diff --git a/paddle/fluid/extension/include/ext_tensor.h b/paddle/fluid/extension/include/ext_tensor.h index 7d13f56b02b8215d034405ee86ac79c4eaff6d6a..970be905cc2566f7de3bac1c593e237cf6c7dac0 100644 --- a/paddle/fluid/extension/include/ext_tensor.h +++ b/paddle/fluid/extension/include/ext_tensor.h @@ -16,8 +16,15 @@ limitations under the License. */ #include #include + #ifdef PADDLE_WITH_CUDA #include +using gpuStream_t = cudaStream_t; +#endif + +#ifdef PADDLE_WITH_HIP +#include +using gpuStream_t = hipStream_t; #endif #include "ext_dll_decl.h" // NOLINT @@ -126,11 +133,9 @@ class PD_DLL_DECL Tensor { /// \brief Check Tensor is initialized bool is_initialized() const; -#if defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) /// \bref Get current stream of Tensor - cudaStream_t stream() const; -#elif defined(PADDLE_WITH_HIP) - hipStream_t stream() const; + gpuStream_t stream() const; #endif private: diff --git a/paddle/fluid/extension/src/ext_tensor.cc b/paddle/fluid/extension/src/ext_tensor.cc index a0a9872c4c29ccfb68fb0a3f09c648b52b1c69fc..b5cd9e0b5c0e1590f6455c317f7056317f46fab9 100644 --- a/paddle/fluid/extension/src/ext_tensor.cc +++ b/paddle/fluid/extension/src/ext_tensor.cc @@ -69,9 +69,9 @@ struct CastDataType { }; template -void DeviceCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc, - int64_t ele_size) { -#if defined(PADDLE_WITH_CUDA) +void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc, + int64_t ele_size) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); int device_num = paddle::platform::GetCurrentDeviceId(); platform::CUDAPlace gpu_place(device_num); @@ -90,29 +90,11 @@ void DeviceCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc, PADDLE_THROW(platform::errors::Unavailable( "Only GPU related Copy can reach this func.")); } -#elif defined(PADDLE_WITH_HIP) - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - int device_num = paddle::platform::GetCurrentDeviceId(); - platform::CUDAPlace gpu_place(device_num); - auto *dev_ctx = - static_cast(pool.Get(gpu_place)); - if ((src_plc == PlaceType::kHIP) && (dst_plc == PlaceType::kCPU)) { - memory::Copy(platform::CPUPlace(), static_cast(dst), gpu_place, src, - ele_size, dev_ctx->stream()); - } else if ((src_plc == PlaceType::kHIP) && (dst_plc == PlaceType::kHIP)) { - memory::Copy(gpu_place, static_cast(dst), gpu_place, src, ele_size, - dev_ctx->stream()); - } else if ((src_plc == PlaceType::kCPU) && (dst_plc == PlaceType::kHIP)) { - memory::Copy(gpu_place, static_cast(dst), platform::CPUPlace(), src, - ele_size, dev_ctx->stream()); - } else { - PADDLE_THROW(platform::errors::Unavailable( - "Only GPU related Copy can reach this func.")); - } +#ifdef PADDLE_WITH_HIP + hipStreamSynchronize(dev_ctx->stream()); #else - PADDLE_THROW(platform::errors::Unavailable( - "This function can only be used if compiled with" - "either -DWITH_ROCM=ON or -DWITH_GPU=ON")); + cudaStreamSynchronize(dev_ctx->stream()); +#endif #endif } @@ -175,16 +157,11 @@ T *Tensor::mutable_data() { case static_cast(PlaceType::kCPU): { return tensor->mutable_data(platform::CPUPlace()); } -#if defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) case static_cast(PlaceType::kGPU): { int device_num = platform::GetCurrentDeviceId(); return tensor->mutable_data(platform::CUDAPlace(device_num)); } -#elif defined(PADDLE_WITH_HIP) - case static_cast(PlaceType::kHIP): { - int device_num = platform::GetCurrentDeviceId(); - return tensor->mutable_data(platform::CUDAPlace(device_num)); - } #endif default: PADDLE_THROW(platform::errors::Unavailable( @@ -245,23 +222,17 @@ Tensor Tensor::copy_to(const PlaceType &target_place) const { target.reshape(shape()); auto *p_target_data = target.template mutable_data(); - bool supported_gpu_transform = false; -#if defined(PADDLE_WITH_CUDA) - supported_gpu_transform = - (src_place == PlaceType::kGPU && target_place == PlaceType::kCPU) || - (src_place == PlaceType::kCPU && target_place == PlaceType::kGPU) || - (src_place == PlaceType::kGPU && target_place == PlaceType::kGPU); -#elif defined(PADDLE_WITH_HIP) - supported_gpu_transform = - (src_place == PlaceType::kHIP && target_place == PlaceType::kCPU) || - (src_place == PlaceType::kCPU && target_place == PlaceType::kHIP) || - (src_place == PlaceType::kHIP && target_place == PlaceType::kHIP); -#endif - if ((src_place == PlaceType::kCPU) && (target_place == PlaceType::kCPU)) { std::memcpy(static_cast(p_target_data), p_src_data, ele_size); - } else if (supported_gpu_transform) { - DeviceCopy(p_src_data, p_target_data, src_place, target_place, ele_size); + } else if ((src_place == PlaceType::kGPU) && + (target_place == PlaceType::kCPU)) { + GpuCopy(p_src_data, p_target_data, src_place, target_place, ele_size); + } else if ((src_place == PlaceType::kCPU) && + (target_place == PlaceType::kGPU)) { + GpuCopy(p_src_data, p_target_data, src_place, target_place, ele_size); + } else if ((src_place == PlaceType::kGPU) && + (target_place == PlaceType::kGPU)) { + GpuCopy(p_src_data, p_target_data, src_place, target_place, ele_size); } else { PADDLE_THROW(platform::errors::Unavailable( "Not supported place transform of place: %d to place: %d", @@ -363,18 +334,15 @@ const PlaceType &Tensor::place() const { GET_CASTED_TENSOR; if (platform::is_cpu_place(tensor->place())) { place_ = PlaceType::kCPU; -#if defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) } else if (platform::is_gpu_place(tensor->place())) { place_ = PlaceType::kGPU; -#elif defined(PADDLE_WITH_HIP) - } else if (platform::is_gpu_place(tensor->place())) { - place_ = PlaceType::kHIP; #endif } else { PADDLE_THROW(platform::errors::Unimplemented( "Current Tensor hold unsupported Place Type, Please Init it" "using Tensor::mutable_data(PaddlePlace) with T among:" - "Place::kCPU or Place::kGPU or Place::kHIP")); + "Place::kCPU or Place::kGPU")); } return place_; } @@ -456,21 +424,16 @@ bool Tensor::is_initialized() const { } } -#define DEFINE_STREAM(_stream_t_) \ - _stream_t_ Tensor::stream() const { \ - if (!stream_.IsStreamSet()) { \ - PADDLE_THROW(platform::errors::PreconditionNotMet( \ - "Stream is not Set, only input tensor will have " \ - "stream which is set by framework ")); \ - } else { \ - return reinterpret_cast<_stream_t_>(stream_.GetStream()); \ - } \ +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +gpuStream_t Tensor::stream() const { + if (!stream_.IsStreamSet()) { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "Stream is not Set, only input tensor will have " + "stream which is set by framework ")); + } else { + return reinterpret_cast(stream_.GetStream()); } - -#if defined(PADDLE_WITH_CUDA) -DEFINE_STREAM(cudaStream_t) -#elif defined(PADDLE_WITH_HIP) -DEFINE_STREAM(hipStream_t) +} #endif namespace framework { diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 19e661587716b396ac1726b72ff483e5d1349d42..bb8258dcd9228f53e106642de19ae8ad96ab7eec 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -503,7 +503,7 @@ void RegisterOperatorKernel(const std::string& name, // but call api in gpu device, it will cause error. RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW, PlaceType::kCPU, inputs, outputs, attrs); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW, PlaceType::kGPU, inputs, outputs, attrs); #endif diff --git a/paddle/fluid/framework/custom_tensor_test.cc b/paddle/fluid/framework/custom_tensor_test.cc index 5d181bfb53bc910e900f24e8c1b3fac6e11484c9..342be27c896ae9924f967005f38315f70abed6f2 100644 --- a/paddle/fluid/framework/custom_tensor_test.cc +++ b/paddle/fluid/framework/custom_tensor_test.cc @@ -38,7 +38,7 @@ void TestCopyTensor() { for (int64_t i = 0; i < t1.size(); i++) { CHECK_EQ(t1_cpu_cp.template data()[i], T(5)); } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) VLOG(2) << "Do GPU copy test"; auto t1_gpu_cp = t1_cpu_cp.template copy_to(paddle::PlaceType::kGPU); CHECK((paddle::PlaceType::kGPU == t1_gpu_cp.place())); @@ -50,33 +50,16 @@ void TestCopyTensor() { for (int64_t i = 0; i < t1.size(); i++) { CHECK_EQ(t1_gpu_cp_cp_cpu.template data()[i], T(5)); } -#elif defined(PADDLE_WITH_HIP) - VLOG(2) << "Do HIP copy test"; - auto t1_gpu_cp = t1_cpu_cp.template copy_to(paddle::PlaceType::kHIP); - CHECK((paddle::PlaceType::kHIP == t1_gpu_cp.place())); - auto t1_gpu_cp_cp = t1_gpu_cp.template copy_to(paddle::PlaceType::kHIP); - CHECK((paddle::PlaceType::kHIP == t1_gpu_cp_cp.place())); - auto t1_gpu_cp_cp_cpu = - t1_gpu_cp_cp.template copy_to(paddle::PlaceType::kCPU); - CHECK((paddle::PlaceType::kCPU == t1_gpu_cp_cp_cpu.place())); - for (int64_t i = 0; i < t1.size(); i++) { - CHECK_EQ(t1_gpu_cp_cp_cpu.template data()[i], T(5)); - } #endif } void TestAPIPlace() { std::vector tensor_shape = {5, 5}; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto t1 = paddle::Tensor(paddle::PlaceType::kGPU); t1.reshape(tensor_shape); t1.mutable_data(); CHECK((paddle::PlaceType::kGPU == t1.place())); -#elif defined(PADDLE_WITH_HIP) - auto t1 = paddle::Tensor(paddle::PlaceType::kHIP); - t1.reshape(tensor_shape); - t1.mutable_data(); - CHECK((paddle::PlaceType::kHIP == t1.place())); #endif auto t2 = paddle::Tensor(paddle::PlaceType::kCPU); t2.reshape(tensor_shape); @@ -97,7 +80,7 @@ void TestAPISlice() { std::vector tensor_shape_sub1 = {3, 5}; std::vector tensor_shape_origin2 = {5, 5, 5}; std::vector tensor_shape_sub2 = {1, 5, 5}; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto t1 = paddle::Tensor(paddle::PlaceType::kGPU, tensor_shape_origin1); t1.mutable_data(); CHECK(t1.slice(0, 5).shape() == tensor_shape_origin1); @@ -144,7 +127,7 @@ void TestCast(paddle::DataType data_type) { t1.template mutable_data(); auto t2 = t1.cast(data_type); CHECK(t2.type() == data_type); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto tg1 = paddle::Tensor(paddle::PlaceType::kGPU); tg1.reshape(tensor_shape); tg1.template mutable_data(); diff --git a/paddle/fluid/framework/custom_tensor_utils.h b/paddle/fluid/framework/custom_tensor_utils.h index 809a6b965aad9bcb4594ecff99e460db723dfd53..d7bde04b84b1615d2f7e2f56e4b1218741faf674 100644 --- a/paddle/fluid/framework/custom_tensor_utils.h +++ b/paddle/fluid/framework/custom_tensor_utils.h @@ -18,11 +18,9 @@ limitations under the License. */ #include "paddle/fluid/extension/include/ext_tensor.h" #include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/place.h" -#ifdef PADDLE_WITH_CUDA -#endif -#include "paddle/fluid/platform/device_context.h" namespace paddle { namespace framework { @@ -110,7 +108,7 @@ class CustomTensorUtils { if (pc == PlaceType::kCPU) { return platform::Place(platform::CPUPlace()); } else if (pc == PlaceType::kGPU) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) return platform::Place( platform::CUDAPlace(platform::GetCurrentDeviceId())); #endif @@ -127,7 +125,7 @@ class CustomTensorUtils { if (platform::is_cpu_place(pc)) { return PlaceType::kCPU; } else if (platform::is_gpu_place(pc)) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) return PlaceType::kGPU; #endif } else { @@ -142,7 +140,7 @@ class CustomTensorUtils { static void SetTensorCurrentStream(paddle::Tensor* src, const platform::Place& pc) { if (platform::is_gpu_place(pc)) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto* dev_ctx = static_cast( platform::DeviceContextPool::Instance().Get(pc)); src->stream_.SetStream(reinterpret_cast(dev_ctx->stream())); diff --git a/python/paddle/utils/cpp_extension/cpp_extension.py b/python/paddle/utils/cpp_extension/cpp_extension.py index 19fa84046ed2d56b425e228b28326cf64a9dbbfc..5370de9ed42aa5f80e7a2b0da0b0c358f2a5d5d1 100644 --- a/python/paddle/utils/cpp_extension/cpp_extension.py +++ b/python/paddle/utils/cpp_extension/cpp_extension.py @@ -403,7 +403,7 @@ class BuildExtension(build_ext, object): cflags = copy.deepcopy(extra_postargs) try: original_compiler = self.compiler.compiler_so - # nvcc compile CUDA source + # nvcc or hipcc compile CUDA source if is_cuda_file(src): if core.is_compiled_with_rocm(): assert ROCM_HOME is not None, "Not found ROCM runtime, \ @@ -429,6 +429,13 @@ class BuildExtension(build_ext, object): elif isinstance(cflags, dict): cflags = cflags['cxx'] + # Note(qili93): HIP require some additional flags for CMAKE_C_FLAGS + if core.is_compiled_with_rocm(): + cflags.append('-D__HIP_PLATFORM_HCC__') + cflags.append('-D__HIP_NO_HALF_CONVERSIONS__=1') + cflags.append( + '-DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP') + # NOTE(Aurelius84): Since Paddle 2.0, we require gcc version > 5.x, # so we add this flag to ensure the symbol names from user compiled # shared library have same ABI suffix with core_(no)avx.so. @@ -436,7 +443,10 @@ class BuildExtension(build_ext, object): add_compile_flag(['-D_GLIBCXX_USE_CXX11_ABI=1'], cflags) # Append this macor only when jointly compiling .cc with .cu if not is_cuda_file(src) and self.contain_cuda_file: - cflags.append('-DPADDLE_WITH_CUDA') + if core.is_compiled_with_rocm(): + cflags.append('-DPADDLE_WITH_HIP') + else: + cflags.append('-DPADDLE_WITH_CUDA') add_std_without_repeat( cflags, self.compiler.compiler_type, use_std14=True) diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index 0a2d71abfdee4f16a142ce1d100d27017681abca..5fee6630342895336b19fb66794bcb75366aea81 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -56,7 +56,12 @@ CLANG_LINK_FLAGS = [ MSVC_LINK_FLAGS = ['/MACHINE:X64'] -COMMON_NVCC_FLAGS = ['-DPADDLE_WITH_CUDA', '-DEIGEN_USE_GPU'] +if core.is_compiled_with_rocm(): + COMMON_HIPCC_FLAGS = [ + '-DPADDLE_WITH_HIP', '-DEIGEN_USE_GPU', '-DEIGEN_USE_HIP' + ] +else: + COMMON_NVCC_FLAGS = ['-DPADDLE_WITH_CUDA', '-DEIGEN_USE_GPU'] GCC_MINI_VERSION = (5, 4, 0) MSVC_MINI_VERSION = (19, 0, 24215) @@ -319,10 +324,14 @@ def prepare_unix_cudaflags(cflags): """ Prepare all necessary compiled flags for nvcc compiling CUDA files. """ - cflags = COMMON_NVCC_FLAGS + [ - '-ccbin', 'cc', '-Xcompiler', '-fPIC', '--expt-relaxed-constexpr', - '-DNVCC' - ] + cflags + get_cuda_arch_flags(cflags) + if core.is_compiled_with_rocm(): + cflags = COMMON_HIPCC_FLAGS + ['-Xcompiler', '-fPIC' + ] + cflags + get_rocm_arch_flags(cflags) + else: + cflags = COMMON_NVCC_FLAGS + [ + '-ccbin', 'cc', '-Xcompiler', '-fPIC', '--expt-relaxed-constexpr', + '-DNVCC' + ] + cflags + get_cuda_arch_flags(cflags) return cflags @@ -358,6 +367,14 @@ def get_cuda_arch_flags(cflags): return [] +def get_rocm_arch_flags(cflags): + """ + For ROCm platform, amdgpu target should be added for HIPCC. + """ + cflags = cflags + ['-fno-gpu-rdc', '-amdgpu-target=gfx906'] + return cflags + + def _get_fluid_path(): """ Return installed fluid dir path. @@ -471,7 +488,10 @@ def normalize_extension_kwargs(kwargs, use_cuda=False): add_compile_flag(extra_compile_args, ['-w']) # disable warning if use_cuda: - extra_link_args.append('-lcudart') + if core.is_compiled_with_rocm(): + extra_link_args.append('-lamdhip64') + else: + extra_link_args.append('-lcudart') kwargs['extra_link_args'] = extra_link_args