Unverified commit 33429630 authored by Qi Li, committed by GitHub

[ROCM] update fluid platform for rocm39 (part4), test=develop (#30936)

Parent a5c56d83
......@@ -16,13 +16,20 @@ endif()
if (WITH_GPU)
nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard)
nv_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator)
nv_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator)
cc_test(thread_local_allocator_test SRCS thread_local_allocator_test.cc DEPS thread_local_allocator)
endif()
if (WITH_ROCM)
hip_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard)
hip_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator)
hip_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator)
cc_test(thread_local_allocator_test SRCS thread_local_allocator_test.cc DEPS thread_local_allocator)
endif()
cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator)
nv_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator)
if (WITH_GPU)
if (WITH_GPU OR WITH_ROCM)
set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard thread_local_allocator)
elseif(WITH_XPU)
set(AllocatorFacadeDeps xpu_info)
......@@ -40,6 +47,16 @@ if (WITH_GPU)
cuda_allocator
device_context
memcpy)
elseif (WITH_ROCM)
hip_test(best_fit_allocator_test
SRCS best_fit_allocator_test.cc
best_fit_allocator_test.cu
DEPS best_fit_allocator
locked_allocator
cpu_allocator
cuda_allocator
device_context
memcpy)
else()
cc_test(best_fit_allocator_test
SRCS best_fit_allocator_test.cc
......@@ -57,7 +74,7 @@ cc_library(allocator_facade SRCS allocator_facade.cc DEPS allocator_strategy)
cc_test(retry_allocator_test SRCS retry_allocator_test.cc DEPS retry_allocator locked_allocator cpu_allocator)
if (WITH_TESTING)
if (WITH_GPU AND TARGET retry_allocator_test)
if ((WITH_GPU OR WITH_ROCM) AND TARGET retry_allocator_test)
target_link_libraries(retry_allocator_test cuda_allocator)
endif()
......
......@@ -12,8 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif
#include <thread> // NOLINT
#include <vector>
......@@ -40,8 +47,13 @@ __global__ void kernel(float *x, int n) {
void CheckKernelOutput(float *x, int n) {
auto host_x = std::unique_ptr<float[]>(new float[n]);
for (int i = 0; i < n; ++i) {
#ifdef PADDLE_WITH_HIP
EXPECT_TRUE(hipSuccess == hipMemcpy(host_x.get(), x, n * sizeof(float),
hipMemcpyDeviceToHost));
#else
EXPECT_TRUE(cudaSuccess == cudaMemcpy(host_x.get(), x, n * sizeof(float),
cudaMemcpyDeviceToHost));
#endif
EXPECT_GE(host_x[i] + DELTA, 3.14159f * i);
EXPECT_LE(host_x[i] - DELTA, 3.14159f * i);
}
......@@ -53,13 +65,22 @@ void MultiStreamCompute(float **data, float **second_data,
AllocationPtr allocation_ptr = Alloc(ctx, N * sizeof(float));
EXPECT_GE(allocation_ptr->size(), N * sizeof(float));
*data = reinterpret_cast<float *>(allocation_ptr->ptr());
#ifdef PADDLE_WITH_HIP
hipLaunchKernelGGL((kernel), dim3(1), dim3(64), 0, ctx.stream(), *data, N);
#else
kernel<<<1, 64, 0, ctx.stream()>>>(*data, N);
#endif
// allocate and compute on same stream again
allocation_ptr = Alloc(ctx, N * sizeof(float));
EXPECT_GE(allocation_ptr->size(), N * sizeof(float));
*second_data = reinterpret_cast<float *>(allocation_ptr->ptr());
#ifdef PADDLE_WITH_HIP
hipLaunchKernelGGL((kernel), dim3(1), dim3(64), 0, ctx.stream(), *second_data,
N);
#else
kernel<<<1, 64, 0, ctx.stream()>>>(*second_data, N);
#endif
}
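For readers new to HIP, the #ifdef branches in this test swap the CUDA triple-chevron launch for the hipLaunchKernelGGL macro. Below is a minimal standalone sketch, not part of this commit, showing how the two launch forms line up; scale_kernel, LaunchScale, and the sketchStream_t alias are invented for illustration only.

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
using sketchStream_t = hipStream_t;   // local alias, only for this sketch
#else
#include <cuda_runtime.h>
using sketchStream_t = cudaStream_t;  // local alias, only for this sketch
#endif

__global__ void scale_kernel(float *x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= 2.0f;
}

void LaunchScale(float *dev_x, int n, sketchStream_t stream) {
  const int threads = 64;
  const int blocks = (n + threads - 1) / threads;
#ifdef PADDLE_WITH_HIP
  // HIP passes the launch configuration as explicit macro arguments:
  // (kernel), grid, block, shared-memory bytes, stream, then kernel arguments.
  hipLaunchKernelGGL((scale_kernel), dim3(blocks), dim3(threads), 0, stream,
                     dev_x, n);
#else
  // CUDA packs the same four configuration values into <<<...>>>.
  scale_kernel<<<blocks, threads, 0, stream>>>(dev_x, n);
#endif
}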
TEST(Malloc, CUDADeviceContextMultiStream) {
......@@ -75,8 +96,12 @@ TEST(Malloc, CUDADeviceContextMultiStream) {
float *second_data[NUM_STREAMS];
CudaDevCtxVec dev_ctx;
// default stream
#ifdef PADDLE_WITH_HIP
hipLaunchKernelGGL((kernel), dim3(1), dim3(64), 0, 0, main_stream_data, N);
#else
kernel<<<1, 64>>>(main_stream_data, N);
#endif
main_stream_alloc_ptr.reset();
for (int i = 0; i < NUM_STREAMS; ++i) {
......@@ -85,7 +110,11 @@ TEST(Malloc, CUDADeviceContextMultiStream) {
MultiStreamCompute(&data[i], &second_data[i], *dev_ctx[i]);
}
#ifdef PADDLE_WITH_HIP
EXPECT_TRUE(hipSuccess == hipDeviceSynchronize());
#else
EXPECT_TRUE(cudaSuccess == cudaDeviceSynchronize());
#endif
for (int i = 0; i < NUM_STREAMS; ++i) {
CheckKernelOutput(data[i], N);
CheckKernelOutput(second_data[i], N);
......@@ -106,8 +135,12 @@ TEST(Malloc, CUDADeviceContextMultiThreadMultiStream) {
CudaDevCtxVec dev_ctx;
std::vector<std::thread> threads;
// default stream
#ifdef PADDLE_WITH_HIP
hipLaunchKernelGGL((kernel), dim3(1), dim3(64), 0, 0, main_stream_data, N);
#else
kernel<<<1, 64>>>(main_stream_data, N);
#endif
main_stream_alloc_ptr.reset();
for (int i = 0; i < NUM_STREAMS; ++i) {
......@@ -120,8 +153,11 @@ TEST(Malloc, CUDADeviceContextMultiThreadMultiStream) {
for (int i = 0; i < NUM_STREAMS; ++i) {
threads[i].join();
}
#ifdef PADDLE_WITH_HIP
EXPECT_TRUE(hipSuccess == hipDeviceSynchronize());
#else
EXPECT_TRUE(cudaSuccess == cudaDeviceSynchronize());
#endif
for (int i = 0; i < NUM_STREAMS; ++i) {
CheckKernelOutput(data[i], N);
CheckKernelOutput(second_data[i], N);
......
......@@ -196,9 +196,22 @@ void Copy<platform::XPUPlace, platform::XPUPlace>(platform::XPUPlace dst_place,
}
#endif
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
static constexpr size_t kMaxGpuAsyncCopyBytes = 64 * 1024; // 64K
#ifdef PADDLE_WITH_HIP
inline void SyncCUDAStream() {
#if !defined(_WIN32)
hipStreamSynchronize(0);
#else
hipError_t e_sync = hipSuccess;
while (e_sync = hipStreamQuery(0)) {
if (e_sync == hipErrorNotReady) continue;
break;
}
#endif
}
#else
inline void SyncCUDAStream() {
#if !defined(_WIN32)
cudaStreamSynchronize(0);
......@@ -210,6 +223,7 @@ inline void SyncCUDAStream() {
}
#endif
}
#endif
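The Windows branch above drains the default stream by polling instead of blocking in hipStreamSynchronize/cudaStreamSynchronize: hipStreamQuery(0) keeps returning hipErrorNotReady while work is pending, and the loop exits on success or on a real error. Below is a standalone sketch of that polling pattern; PollDefaultStream is a hypothetical name used only for illustration.

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif

inline void PollDefaultStream() {
#ifdef PADDLE_WITH_HIP
  hipError_t e = hipSuccess;
  // hipStreamQuery(0) returns hipErrorNotReady while work is still pending on
  // the default stream, and hipSuccess (or a real error) once it is finished.
  while ((e = hipStreamQuery(0)) == hipErrorNotReady) {
    // spin; a real implementation might yield or sleep here
  }
#else
  cudaError_t e = cudaSuccess;
  while ((e = cudaStreamQuery(0)) == cudaErrorNotReady) {
    // spin until the default stream reports completion or an error
  }
#endif
  (void)e;  // e now holds the final status of the default stream
}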
// NOTE(zcd): Do not use GpuMemcpySync as much as possible.
// because GpuMemcpySync issues the copying command to the default stream,
......@@ -228,10 +242,18 @@ void Copy<platform::CPUPlace, platform::CUDAPlace>(
<< dst_place << " by thream(" << stream << ")";
if (stream) {
platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CPU");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToHost, stream);
#else
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
#endif
} else {
platform::RecordEvent record_event("GpuMemcpySync:GPU->CPU");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToHost);
#else
platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
#endif
// FIXME(zjl): do we really need it?
if (num <= kMaxGpuAsyncCopyBytes) {
SyncCUDAStream();
......@@ -250,10 +272,18 @@ void Copy<platform::CUDAPlace, platform::CPUPlace>(
<< dst_place << " by thream(" << stream << ")";
if (stream) {
platform::RecordEvent record_event("GpuMemcpyAsync:CPU->GPU");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpyAsync(dst, src, num, hipMemcpyHostToDevice, stream);
#else
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
#endif
} else {
platform::RecordEvent record_event("GpuMemcpySync:CPU->GPU");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpySync(dst, src, num, hipMemcpyHostToDevice);
#else
platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
#endif
// FIXME(zjl): do we really need it?
if (num <= kMaxGpuAsyncCopyBytes) {
SyncCUDAStream();
......@@ -273,10 +303,18 @@ void Copy<platform::CUDAPlace, platform::CUDAPlace>(
platform::SetDeviceId(src_place.device);
if (stream) {
platform::RecordEvent record_event("GpuMemcpyAsync(same_gpu):GPU->GPU");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToDevice, stream);
#else
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
#endif
} else {
platform::RecordEvent record_event("GpuMemcpySync(same_gpu):GPU->GPU");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToDevice);
#else
platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
#endif
}
} else {
if (stream) {
......@@ -332,10 +370,18 @@ void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
<< dst_place << " by thream(" << stream << ")";
if (stream) {
platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CUDAPinned");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToHost, stream);
#else
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
#endif
} else {
platform::RecordEvent record_event("GpuMemcpySync:GPU->CUDAPinned");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToHost);
#else
platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
#endif
}
}
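Each Copy specialization in this file repeats the same #ifdef pair to pick the hipMemcpy* or cudaMemcpy* direction constant. The sketch below is a hypothetical helper, not part of this commit or of Paddle's API, that illustrates the direction-to-kind mapping in one place.

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
using sketchMemcpyKind = hipMemcpyKind;   // local alias, only for this sketch
#else
#include <cuda_runtime.h>
using sketchMemcpyKind = cudaMemcpyKind;  // local alias, only for this sketch
#endif

enum class CopyDir { kDeviceToHost, kHostToDevice, kDeviceToDevice };

inline sketchMemcpyKind ToMemcpyKind(CopyDir dir) {
  switch (dir) {
#ifdef PADDLE_WITH_HIP
    case CopyDir::kDeviceToHost:
      return hipMemcpyDeviceToHost;
    case CopyDir::kHostToDevice:
      return hipMemcpyHostToDevice;
    default:
      return hipMemcpyDeviceToDevice;
#else
    case CopyDir::kDeviceToHost:
      return cudaMemcpyDeviceToHost;
    case CopyDir::kHostToDevice:
      return cudaMemcpyHostToDevice;
    default:
      return cudaMemcpyDeviceToDevice;
#endif
  }
}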
......@@ -351,10 +397,18 @@ void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
<< dst_place << " by thream(" << stream << ")";
if (stream) {
platform::RecordEvent record_event("GpuMemcpyAsync:CUDAPinned->GPU");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpyAsync(dst, src, num, hipMemcpyHostToDevice, stream);
#else
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
#endif
} else {
platform::RecordEvent record_event("GpuMemcpySync:CUDAPinned->GPU");
#ifdef PADDLE_WITH_HIP
platform::GpuMemcpySync(dst, src, num, hipMemcpyHostToDevice);
#else
platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
#endif
}
}
......
......@@ -41,27 +41,44 @@ float test_pinned_memory() {
const int iteration = 10;
// create event start and end
cudaEvent_t start_e, stop_e, copying_e;
gpuEvent_t start_e, stop_e, copying_e;
float elapsedTime = 0;
#ifdef PADDLE_WITH_HIP
hipEventCreate(&start_e);
hipEventCreate(&stop_e);
hipEventCreate(&copying_e);
#else
cudaEventCreate(&start_e);
cudaEventCreate(&stop_e);
cudaEventCreate(&copying_e);
#endif
// create computation stream, data copying stream
cudaStream_t computation_stream, copying_stream;
gpuStream_t computation_stream, copying_stream;
#ifdef PADDLE_WITH_HIP
hipStreamCreate(&computation_stream);
hipStreamCreate(&copying_stream);
#else
cudaStreamCreate(&computation_stream);
cudaStreamCreate(&copying_stream);
#endif
// create record event, pinned memory, gpu memory
std::vector<cudaEvent_t> record_event(iteration);
std::vector<gpuEvent_t> record_event(iteration);
std::vector<float*> input_pinned_mem(iteration);
std::vector<float*> gpu_mem(iteration);
std::vector<float*> output_pinned_mem(iteration);
// initial data
for (int j = 0; j < iteration; ++j) {
#ifdef PADDLE_WITH_HIP
hipEventCreateWithFlags(&record_event[j], hipEventDisableTiming);
hipEventCreate(&(record_event[j]));
#else
cudaEventCreateWithFlags(&record_event[j], cudaEventDisableTiming);
cudaEventCreate(&(record_event[j]));
#endif
input_pinned_mem[j] = static_cast<float*>(
paddle::memory::Alloc(cpu_place, data_size * sizeof(float)));
output_pinned_mem[j] = static_cast<float*>(
......@@ -74,7 +91,11 @@ float test_pinned_memory() {
}
}
#ifdef PADDLE_WITH_HIP
hipEventRecord(start_e, computation_stream);
#else
cudaEventRecord(start_e, computation_stream);
#endif
// computation
for (int m = 0; m < 30; ++m) {
......@@ -88,13 +109,21 @@ float test_pinned_memory() {
// call kernel on computation stream.
Kernel<<<4, 1024, 0, computation_stream>>>(gpu_mem[i], data_size);
#ifdef PADDLE_WITH_HIP
// record event_computation on computation stream
hipEventRecord(record_event[i], computation_stream);
// wait event_computation on copy stream.
// note: this operation is async.
hipStreamWaitEvent(copying_stream, record_event[i], 0);
#else
// record event_computation on computation stream
cudaEventRecord(record_event[i], computation_stream);
// wait event_computation on copy stream.
// note: this operation is async.
cudaStreamWaitEvent(copying_stream, record_event[i], 0);
#endif
// copy data GPU->CPU, on copy stream.
// note: this operation is async for pinned memory.
paddle::memory::Copy(cpu_place, output_pinned_mem[i], cuda_place,
......@@ -103,6 +132,16 @@ float test_pinned_memory() {
}
}
#ifdef PADDLE_WITH_HIP
hipEventRecord(copying_e, copying_stream);
hipStreamWaitEvent(computation_stream, copying_e, 0);
hipEventRecord(stop_e, computation_stream);
hipEventSynchronize(start_e);
hipEventSynchronize(stop_e);
hipEventElapsedTime(&elapsedTime, start_e, stop_e);
#else
cudaEventRecord(copying_e, copying_stream);
cudaStreamWaitEvent(computation_stream, copying_e, 0);
......@@ -111,6 +150,7 @@ float test_pinned_memory() {
cudaEventSynchronize(start_e);
cudaEventSynchronize(stop_e);
cudaEventElapsedTime(&elapsedTime, start_e, stop_e);
#endif
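The block above follows the standard event-timing recipe: record start on the computation stream, fence the streams against each other with events, record stop, synchronize on the events, then read the elapsed milliseconds. Below is a minimal standalone sketch of just the timing part; TimeDefaultStreamMs is an invented name and the default stream is assumed.

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif

inline float TimeDefaultStreamMs(void (*work)()) {
#ifdef PADDLE_WITH_HIP
  hipEvent_t start, stop;
  hipEventCreate(&start);
  hipEventCreate(&stop);
  hipEventRecord(start, 0);   // mark the beginning on the default stream
  work();                     // enqueue the work being measured
  hipEventRecord(stop, 0);    // mark the end
  hipEventSynchronize(stop);  // wait until 'stop' has actually been reached
  float ms = 0.f;
  hipEventElapsedTime(&ms, start, stop);
  hipEventDestroy(start);
  hipEventDestroy(stop);
  return ms;
#else
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start, 0);
  work();
  cudaEventRecord(stop, 0);
  cudaEventSynchronize(stop);
  float ms = 0.f;
  cudaEventElapsedTime(&ms, start, stop);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms;
#endif
}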
// std::cout << cpu_place << " "
// << "time consume:" << elapsedTime / 30 << std::endl;
......@@ -123,12 +163,22 @@ float test_pinned_memory() {
}
}
// destroy resource
#ifdef PADDLE_WITH_HIP
hipEventDestroy(copying_e);
hipEventDestroy(start_e);
hipEventDestroy(stop_e);
#else
cudaEventDestroy(copying_e);
cudaEventDestroy(start_e);
cudaEventDestroy(stop_e);
#endif
for (int j = 0; j < 10; ++j) {
#ifdef PADDLE_WITH_HIP
hipEventDestroy((record_event[j]));
#else
cudaEventDestroy((record_event[j]));
#endif
paddle::memory::Free(cpu_place, input_pinned_mem[j]);
paddle::memory::Free(cpu_place, output_pinned_mem[j]);
paddle::memory::Free(cuda_place, gpu_mem[j]);
......
......@@ -21,7 +21,7 @@ size_t Alignment(size_t size, const platform::Place &place) {
if (platform::is_cpu_place(place)) {
alignment = CpuMinChunkSize();
} else {
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
alignment = GpuMinChunkSize();
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/gpu_info.h"
#endif
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
......@@ -337,7 +338,7 @@ void* GetNVRTCDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvrtc.dylib", false);
#elif defined(PADDLE_WITH_HIP)
return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libhiprtc.so", false);
return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libamdhip64.so", false);
#else
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvrtc.so", false);
#endif
......@@ -347,7 +348,7 @@ void* GetCUDADsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcuda.dylib", false);
#elif defined(PADDLE_WITH_HIP)
return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libhip_hcc.so", false);
return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libamdhip64.so", false);
#else
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcuda.so", false);
#endif
......
......@@ -45,6 +45,7 @@ extern bool HasNVRTC();
* include all needed hiprtc functions
**/
#define HIPRTC_ROUTINE_EACH(__macro) \
__macro(hiprtcVersion); \
__macro(hiprtcGetErrorString); \
__macro(hiprtcCompileProgram); \
__macro(hiprtcCreateProgram); \
......
......@@ -16,10 +16,15 @@ limitations under the License. */
#include <glog/logging.h>
#include <miopen/miopen.h>
#include <miopen/version.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
#define MIOPEN_VERSION \
(MIOPEN_VERSION_MAJOR * 1000 + MIOPEN_VERSION_MINOR * 100 + \
MIOPEN_VERSION_PATCH) // NOLINT
namespace paddle {
namespace platform {
namespace dynload {
......
......@@ -46,6 +46,7 @@ extern bool HasCUDADriver();
* include all needed cuda driver functions
**/
#define ROCM_ROUTINE_EACH(__macro) \
__macro(hipDriverGetVersion); \
__macro(hipGetErrorString); \
__macro(hipModuleLoadData); \
__macro(hipModuleGetFunction); \
......
......@@ -18,6 +18,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif
#include "paddle/fluid/platform/place.h"
namespace paddle {
......@@ -48,9 +51,9 @@ class Event {
void set_name(std::string name) { name_ = name; }
void set_role(EventRole role) { role_ = role; }
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifndef PADDLE_WITH_CUPTI
cudaEvent_t event() const { return event_; }
gpuEvent_t event() const { return event_; }
int device() const { return device_; }
#endif
#endif
......@@ -66,7 +69,7 @@ class Event {
EventRole role_{};
int64_t cpu_ns_;
bool visited_status_{false};
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifdef PADDLE_WITH_CUPTI
int64_t gpu_ns_ = 0;
......@@ -77,7 +80,7 @@ class Event {
private:
#else
cudaEvent_t event_ = nullptr;
gpuEvent_t event_ = nullptr;
int device_ = -1;
#endif
#endif
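This hunk replaces cudaEvent_t with gpuEvent_t, which resolves to hipEvent_t or cudaEvent_t depending on the build; the actual alias definitions live elsewhere in the Paddle tree and are not shown in this diff. A sketch of what such aliases look like, for illustration only:

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
using gpuEvent_t = hipEvent_t;
using gpuStream_t = hipStream_t;
using gpuError_t = hipError_t;
#else
#include <cuda_runtime.h>
using gpuEvent_t = cudaEvent_t;
using gpuStream_t = cudaStream_t;
using gpuError_t = cudaError_t;
#endif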
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#include "gflags/gflags.h"
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#endif
......@@ -45,7 +45,7 @@ DEFINE_bool(check_nan_inf, false,
"Checking whether operator produce NAN/INF or not. It will be "
"extremely slow so please use this flag wisely.");
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
/**
* CUDA related related FLAG
......@@ -84,7 +84,7 @@ DEFINE_string(selected_gpus, "",
"share-memory only.");
#endif
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
/**
* CUDNN related FLAG
......@@ -167,7 +167,7 @@ DEFINE_bool(cudnn_batchnorm_spatial_persistent, false,
"batch_norm, default is False.");
#endif
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
/**
* NCCL related FLAG
......@@ -377,7 +377,7 @@ DEFINE_double(
"Default use 50% of CPU memory as the pinned_memory for PaddlePaddle,"
"reserve the rest for page tables, etc");
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
/**
* Memory related FLAG
......
......@@ -40,7 +40,7 @@ struct ForRange<CPUDeviceContext> {
size_t limit_;
};
#ifdef __NVCC__
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename Function>
__global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
size_t idx = static_cast<size_t>(threadIdx.x);
......
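Extending the guard to __HIPCC__ lets the same ForRange kernels compile under hipcc. Below is a simplified sketch of the element-wise dispatch pattern these kernels implement, assuming a device functor that is called once per index in [0, limit); it is not Paddle's exact implementation.

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif

template <typename Function>
__global__ void ForRangeSketchKernel(Function func, size_t limit) {
  // one thread per element; extra threads simply return
  size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < limit) {
    func(idx);
  }
}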
......@@ -16,8 +16,10 @@ limitations under the License. */
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/cpu_info.h"
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cupti.h"
#endif
#include "paddle/fluid/platform/device_context.h"
......@@ -92,6 +94,7 @@ bool InitGflags(std::vector<std::string> args) {
return successed;
}
#ifdef PADDLE_WITH_CUDA
void InitCupti() {
#ifdef PADDLE_WITH_CUPTI
if (FLAGS_multiple_of_cupti_buffer_size == 1) return;
......@@ -117,14 +120,17 @@ void InitCupti() {
#undef MULTIPLY_ATTR_VALUE
#endif
}
#endif
void InitDevices() {
// CUPTI attribute should be set before any CUDA context is created (see CUPTI
// documentation about CUpti_ActivityAttribute).
#ifdef PADDLE_WITH_CUDA
InitCupti();
#endif
/*Init all available devices by default */
std::vector<int> devices;
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
try {
// use user specified GPUs in single-node multi-process mode.
devices = platform::GetSelectedDevices();
......@@ -154,7 +160,7 @@ void InitDevices(const std::vector<int> devices) {
continue;
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
places.emplace_back(platform::CUDAPlace(devices[i]));
#endif
#ifdef PADDLE_WITH_XPU
......@@ -162,7 +168,7 @@ void InitDevices(const std::vector<int> devices) {
#endif
}
places.emplace_back(platform::CPUPlace());
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
places.emplace_back(platform::CUDAPinnedPlace());
#endif
platform::DeviceContextPool::Init(places);
......
......@@ -19,7 +19,8 @@ TEST(InitDevices, CPU) {
using paddle::framework::InitDevices;
using paddle::platform::DeviceContextPool;
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_XPU)
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_XPU) && \
!defined(PADDLE_WITH_HIP)
InitDevices();
DeviceContextPool& pool = DeviceContextPool::Instance();
ASSERT_EQ(pool.size(), 1U);
......@@ -30,7 +31,7 @@ TEST(InitDevices, CUDA) {
using paddle::framework::InitDevices;
using paddle::platform::DeviceContextPool;
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
int count = paddle::platform::GetCUDADeviceCount();
InitDevices();
DeviceContextPool& pool = DeviceContextPool::Instance();
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/dynload/miopen.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/macros.h"
// MIOPEN does not have an epsilon definition
#define CUDNN_BN_MIN_EPSILON 1e-05
namespace paddle {
namespace platform {
struct float16;
} // namespace platform
} // namespace paddle
DECLARE_bool(cudnn_deterministic);
namespace paddle {
namespace platform {
// MIOPEN only supports NCHW; this is kept for compatibility with the cuDNN API
typedef enum {
MIOPEN_TENSOR_NCHW = 0,
MIOPEN_TENSOR_NHWC = 1,
} miopenTensorFormat_t;
// MIOPEN does not support the indirect function call defined in cudnnWorkspaceHandle
struct miopenWorkspace {
explicit miopenWorkspace(size_t size) : size(size), data(NULL) {
PADDLE_ENFORCE_CUDA_SUCCESS(hipMalloc(&data, size));
}
miopenWorkspace(const miopenWorkspace&) = delete;
miopenWorkspace(miopenWorkspace&&) = default;
miopenWorkspace& operator=(miopenWorkspace&&) = default;
~miopenWorkspace() {
if (data) {
hipFree(data);
}
}
size_t size;
void* data;
};
inline const char* miopenGetErrorString(miopenStatus_t status) {
switch (status) {
case miopenStatusSuccess:
return "miopenStatusSuccess";
case miopenStatusNotInitialized:
return "miopenStatusNotInitialized";
case miopenStatusAllocFailed:
return "miopenStatusAllocFailed";
case miopenStatusBadParm:
return "miopenStatusBadParm";
case miopenStatusInternalError:
return "miopenStatusInternalError";
case miopenStatusInvalidValue:
return "miopenStatusInvalidValue";
case miopenStatusUnknownError:
return "miopenStatusUnknownError";
case miopenStatusNotImplemented:
return "miopenStatusNotImplemented";
default:
return "Unknown miopen error number";
}
}
// not used, but there will be a compile error if it is not defined
#define CUDNN_VERSION_MIN(major, minor, patch) \
(CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch)))
enum class DataLayout {  // Not used
kNHWC,
kNCHW,
kNCDHW,
kNDHWC, // add, liyamei
kNCHW_VECT_C,
};
enum class PoolingMode {
kMaximum,
kMaximumDeterministic,
kAverageExclusive,
kAverageInclusive,
};
enum class ActivationMode {
kNone, // activation identity
kSigmoid,
kRelu,
kRelu6,
kReluX,
kTanh,
kBandPass,
};
inline miopenPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
switch (mode) {
case PoolingMode::kMaximumDeterministic:
return miopenPoolingMax;
case PoolingMode::kAverageExclusive:
return miopenPoolingAverage;
case PoolingMode::kAverageInclusive:
return miopenPoolingAverageInclusive;
case PoolingMode::kMaximum:
return miopenPoolingMax;
default:
PADDLE_THROW(
platform::errors::Unimplemented("Unexpected MIOPEN pooling mode."));
}
}
inline ActivationMode StringToActivationMode(const std::string& str) {
if (str == "identity") {
return ActivationMode::kNone;
} else if (str == "sigmoid") {
return ActivationMode::kSigmoid;
} else if (str == "relu") {
return ActivationMode::kRelu;
} else if (str == "relu6") {
return ActivationMode::kRelu6;
} else if (str == "relux") {
return ActivationMode::kReluX;
} else if (str == "tanh") {
return ActivationMode::kTanh;
} else if (str == "bandpass") {
return ActivationMode::kBandPass;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unknown MIOPEN activation string: %s.", str));
}
}
template <typename T>
class CudnnDataType;
template <>
class CudnnDataType<float16> {
public:
static const miopenDataType_t type = miopenHalf;
// The scaling param type is float for HALF and FLOAT tensors
using ScalingParamType = const float;
using BatchNormParamType = float;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
return &v;
}
static ScalingParamType* kZero() {
static ScalingParamType v = 0.0;
return &v;
}
};
template <>
class CudnnDataType<float> {
public:
static const miopenDataType_t type = miopenFloat;
using ScalingParamType = const float;
using BatchNormParamType = float;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
return &v;
}
static ScalingParamType* kZero() {
static ScalingParamType v = 0.0;
return &v;
}
};
inline miopenTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) {
switch (order) {
case DataLayout::kNHWC:
return MIOPEN_TENSOR_NHWC;
case DataLayout::kNCHW:
return MIOPEN_TENSOR_NCHW;
case DataLayout::kNCDHW:
return MIOPEN_TENSOR_NCHW;
case DataLayout::kNDHWC:
return MIOPEN_TENSOR_NHWC;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"MIOPEN has no equivalent dataLayout for input order."));
}
return MIOPEN_TENSOR_NCHW;
}
class ScopedTensorDescriptor {
public:
ScopedTensorDescriptor() {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenCreateTensorDescriptor(&desc_));
}
~ScopedTensorDescriptor() PADDLE_MAY_THROW {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyTensorDescriptor(desc_));
}
inline miopenTensorDescriptor_t descriptor(const miopenTensorFormat_t format,
const miopenDataType_t type,
const std::vector<int>& dims,
const int groups = 1) {
// the format is not used now; it will be added later
std::vector<int> strides(dims.size());
strides[dims.size() - 1] = 1;
for (int i = dims.size() - 2; i >= 0; i--) {
strides[i] = dims[i + 1] * strides[i + 1];
}
// Update tensor descriptor dims setting if groups > 1
// NOTE: Here, assume NCHW or NCDHW order
std::vector<int> dims_with_group(dims.begin(), dims.end());
if (groups > 1) {
dims_with_group[1] = dims_with_group[1] / groups;
}
// MIOPEN ONLY supports the NCHW data layout
PADDLE_ENFORCE_EQ(format, MIOPEN_TENSOR_NCHW,
platform::errors::InvalidArgument(
"format should ONLY be NCHW in MIOPEN."));
if (dims.size() == 4) {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetTensorDescriptor(
desc_, type, dims_with_group.size(),
const_cast<int*>(dims_with_group.data()),
const_cast<int*>(strides.data())));
} else if (dims.size() == 5) {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetTensorDescriptor(
desc_, type, dims_with_group.size(),
const_cast<int*>(dims_with_group.data()),
const_cast<int*>(strides.data())));
}
return desc_;
}
template <typename T>
inline miopenTensorDescriptor_t descriptor(const DataLayout& order,
const std::vector<int>& dims,
const int groups = 1) {
return descriptor(GetCudnnTensorFormat(order), CudnnDataType<T>::type, dims,
groups);
}
inline miopenTensorDescriptor_t descriptor(const miopenDataType_t miopen_type,
const std::vector<int>& dim,
const std::vector<int>& stride) {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetTensorDescriptor(
desc_, miopen_type, dim.size(), const_cast<int*>(dim.data()),
const_cast<int*>(stride.data())));
return desc_;
}
template <typename T>
inline miopenTensorDescriptor_t descriptor(const std::vector<int>& dim,
const std::vector<int>& stride) {
return descriptor(CudnnDataType<T>::type, dim, stride);
}
inline miopenTensorDescriptor_t desc() { return desc_; }
private:
miopenTensorDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedTensorDescriptor);
};
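descriptor() above builds row-major NCHW strides: the innermost dimension gets stride 1 and every outer stride is the product of the inner extents. Below is a small standalone sketch of that computation with the hypothetical helper name NchwStrides; for dims {2, 4, 6, 6} it yields {144, 36, 6, 1}, which matches the expectations in the MIOpenHelper test at the end of this diff.

#include <vector>

inline std::vector<int> NchwStrides(const std::vector<int>& dims) {
  std::vector<int> strides(dims.size());
  strides[dims.size() - 1] = 1;  // innermost (W) dimension is contiguous
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
    strides[i] = dims[i + 1] * strides[i + 1];  // product of inner extents
  }
  return strides;
}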
class ScopedDropoutDescriptor {
public:
ScopedDropoutDescriptor() {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenCreateDropoutDescriptor(&desc_));
}
~ScopedDropoutDescriptor() PADDLE_MAY_THROW {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyDropoutDescriptor(desc_));
}
inline miopenDropoutDescriptor_t descriptor(const miopenHandle_t& handle,
const platform::Place& place,
bool initialized,
float dropout_prob_,
framework::Tensor* dropout_state_,
int seed, size_t state_size) {
if (dropout_state_ == nullptr) { // for no dropout or test
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetDropoutDescriptor(
desc_, handle, 0 /* dropout */, nullptr, 0 /* state_size */,
0 /* seed */, false, false, MIOPEN_RNG_PSEUDO_XORWOW));
return desc_;
}
auto* dropout_state_data = dropout_state_->data<uint8_t>();
if (!initialized) {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetDropoutDescriptor(
desc_, handle, dropout_prob_, dropout_state_data, state_size, seed,
false, false, MIOPEN_RNG_PSEUDO_XORWOW));
} else {
auto dropout_state_dims = dropout_state_->dims();
state_size = dropout_state_dims[0];
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenRestoreDropoutDescriptor(
desc_, handle, dropout_prob_, dropout_state_data, state_size, 0,
false, false, MIOPEN_RNG_PSEUDO_XORWOW));
}
return desc_;
}
inline miopenDropoutDescriptor_t desc() { return desc_; }
private:
miopenDropoutDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedDropoutDescriptor);
};
class ScopedRNNDescriptor {
public:
ScopedRNNDescriptor() {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenCreateRNNDescriptor(&desc_));
}
~ScopedRNNDescriptor() PADDLE_MAY_THROW {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyRNNDescriptor(desc_));
}
inline miopenRNNDescriptor_t desc() { return desc_; }
private:
miopenRNNDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedRNNDescriptor);
};
class ScopedFilterDescriptor {
public:
ScopedFilterDescriptor() {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenCreateTensorDescriptor(&desc_));
}
~ScopedFilterDescriptor() PADDLE_MAY_THROW {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyTensorDescriptor(desc_));
}
inline miopenTensorDescriptor_t descriptor(const miopenTensorFormat_t format,
const miopenDataType_t type,
const std::vector<int>& kernel,
const int groups = 1) {
// filter layout: MCHW(MCDHW), where M is the number of
// output image channels, C is the number of input image channels,
// D is the depth of the filter, H is the height of the filter, and W is the
// width of the filter.
std::vector<int> kernel_with_group(kernel.begin(), kernel.end());
if (groups > 1) {
kernel_with_group[0] /= groups;
// NOTE: input filter(C) of the filter is already asserted to be C/groups.
}
std::vector<int> stride_dim(kernel_with_group.size());
stride_dim.push_back(1);
for (int k = kernel_with_group.size() - 2; k >= 0; k--) {
stride_dim[k] = stride_dim[k + 1] * kernel_with_group[k + 1];
}
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetTensorDescriptor(
desc_, type, kernel_with_group.size(),
const_cast<int*>(kernel_with_group.data()),
const_cast<int*>(stride_dim.data())));
return desc_;
}
template <typename T>
inline miopenTensorDescriptor_t descriptor(const DataLayout& order,
const std::vector<int>& kernel,
const int groups = 1) {
return descriptor(GetCudnnTensorFormat(order), CudnnDataType<T>::type,
kernel, groups);
}
inline miopenTensorDescriptor_t desc() { return desc_; }
private:
miopenTensorDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedFilterDescriptor);
};
class ScopedConvolutionDescriptor {
public:
ScopedConvolutionDescriptor() {
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::miopenCreateConvolutionDescriptor(&desc_));
}
~ScopedConvolutionDescriptor() PADDLE_MAY_THROW {
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::miopenDestroyConvolutionDescriptor(desc_));
}
inline miopenConvolutionDescriptor_t descriptor(
miopenDataType_t type, const std::vector<int>& pads,
const std::vector<int>& strides, const std::vector<int>& dilations) {
PADDLE_ENFORCE_EQ(pads.size(), strides.size(),
platform::errors::InvalidArgument(
"The size of pads and strides should be equal. But "
"received size of pads is %d, size of strides is %d.",
pads.size(), strides.size()));
PADDLE_ENFORCE_EQ(
pads.size(), dilations.size(),
platform::errors::InvalidArgument(
"The size of pads and dilations should be equal. But received size "
"of pads is %d, size of dilations is %d.",
pads.size(), dilations.size()));
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenInitConvolutionNdDescriptor(
desc_, pads.size(), const_cast<int*>(pads.data()),
const_cast<int*>(strides.data()), const_cast<int*>(dilations.data()),
miopenConvolution));
return desc_;
}
template <typename T>
inline miopenConvolutionDescriptor_t descriptor(
const std::vector<int>& pads, const std::vector<int>& strides,
const std::vector<int>& dilations) {
return descriptor(CudnnDataType<T>::type, pads, strides, dilations);
}
private:
miopenConvolutionDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedConvolutionDescriptor);
};
class ScopedPoolingDescriptor {
public:
ScopedPoolingDescriptor() {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenCreatePoolingDescriptor(&desc_));
}
~ScopedPoolingDescriptor() PADDLE_MAY_THROW {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyPoolingDescriptor(desc_));
}
inline miopenPoolingDescriptor_t descriptor(const PoolingMode& mode,
const std::vector<int>& kernel,
const std::vector<int>& pads,
const std::vector<int>& strides) {
PADDLE_ENFORCE_EQ(kernel.size(), pads.size(),
platform::errors::InvalidArgument(
"The size of kernel and pads should be equal. But "
"received size of kernel is %d, size of pads is %d.",
kernel.size(), pads.size()));
PADDLE_ENFORCE_EQ(
kernel.size(), strides.size(),
platform::errors::InvalidArgument(
"The size of kernel and strides should be equal. But "
"received size of kernel is %d, size of strides is %d.",
kernel.size(), strides.size()));
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSet2dPoolingDescriptor(
desc_, GetPoolingMode(mode), kernel[0], kernel[1], pads[0], pads[1],
strides[0], strides[1]));
return desc_;
}
private:
miopenPoolingDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
};
class ScopedActivationDescriptor {
public:
ScopedActivationDescriptor() {
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::miopenCreateActivationDescriptor(&desc_));
}
~ScopedActivationDescriptor() PADDLE_MAY_THROW {
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::miopenDestroyActivationDescriptor(desc_));
}
template <typename T>
inline miopenActivationDescriptor_t descriptor(
const std::string& act, double value_max = static_cast<double>(0.)) {
double relu_ceiling = 0.0;
ActivationMode activation_mode = StringToActivationMode(act);
miopenActivationMode_t mode;
switch (activation_mode) {
case ActivationMode::kNone:
mode = miopenActivationPASTHRU;
break;
case ActivationMode::kRelu6:
relu_ceiling = 6.0;
mode = miopenActivationCLIPPEDRELU;
break;
case ActivationMode::kReluX:
relu_ceiling = value_max;
mode = miopenActivationCLIPPEDRELU;
break;
case ActivationMode::kRelu:
mode = miopenActivationRELU;
break;
case ActivationMode::kSigmoid:
mode = miopenActivationLOGISTIC;
break;
case ActivationMode::kTanh:
mode = miopenActivationTANH;
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unrecognized MIOPEN activation mode: %d.",
static_cast<int>(activation_mode)));
}
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetActivationDescriptor(
desc_, mode, relu_ceiling, 0.0, 0.0));
return desc_;
}
private:
miopenActivationDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedActivationDescriptor);
};
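A hypothetical usage sketch of the activation path, not taken from the commit: StringToActivationMode maps "relu6" to kRelu6, which descriptor() turns into miopenActivationCLIPPEDRELU with a ceiling of 6.0. Note that the descriptor stays valid only while the scoped wrapper is alive.

#include "paddle/fluid/platform/miopen_helper.h"

void UseRelu6Descriptor() {
  paddle::platform::ScopedActivationDescriptor scoped_desc;
  // descriptor<T>() selects the MIOpen mode from the activation string and
  // sets the relu ceiling (6.0 for "relu6").
  miopenActivationDescriptor_t desc = scoped_desc.descriptor<float>("relu6");
  // ... pass `desc` to MIOpen calls while `scoped_desc` is still in scope ...
  (void)desc;
}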
inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_HIP
if (use_cudnn) {
auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
}
#endif
return use_cudnn;
}
class ScopedCTCLossDescriptor {
public:
ScopedCTCLossDescriptor() {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenCreateCTCLossDescriptor(&desc_));
}
~ScopedCTCLossDescriptor() PADDLE_MAY_THROW {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyCTCLossDescriptor(desc_));
}
template <typename T>
inline miopenCTCLossDescriptor_t descriptor() {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetCTCLossDescriptor(
desc_, CudnnDataType<T>::type, 0, false));
return desc_;
}
private:
miopenCTCLossDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedCTCLossDescriptor);
};
} // namespace platform
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define GLOG_NO_ABBREVIATED_SEVERITIES
#define GOOGLE_GLOG_DLL_DECL
#include "paddle/fluid/platform/miopen_helper.h"
#include <gtest/gtest.h>
TEST(MIOpenHelper, ScopedTensorDescriptor) {
using paddle::platform::ScopedTensorDescriptor;
using paddle::platform::DataLayout;
ScopedTensorDescriptor tensor_desc;
std::vector<int> shape = {2, 4, 6, 6};
auto desc = tensor_desc.descriptor<float>(DataLayout::kNCHW, shape);
miopenDataType_t type;
int nd;
std::vector<int> dims(4);
std::vector<int> strides(4);
paddle::platform::dynload::miopenGetTensorDescriptor(desc, &type, dims.data(),
strides.data());
paddle::platform::dynload::miopenGetTensorDescriptorSize(desc, &nd);
EXPECT_EQ(nd, 4);
for (size_t i = 0; i < dims.size(); ++i) {
EXPECT_EQ(dims[i], shape[i]);
}
EXPECT_EQ(strides[3], 1);
EXPECT_EQ(strides[2], 6);
EXPECT_EQ(strides[1], 36);
EXPECT_EQ(strides[0], 144);
// test tensor5d: ScopedTensorDescriptor
ScopedTensorDescriptor tensor5d_desc;
std::vector<int> shape_5d = {2, 4, 6, 6, 6};
auto desc_5d = tensor5d_desc.descriptor<float>(DataLayout::kNCDHW, shape_5d);
std::vector<int> dims_5d(5);
std::vector<int> strides_5d(5);
paddle::platform::dynload::miopenGetTensorDescriptor(
desc_5d, &type, dims_5d.data(), strides_5d.data());
paddle::platform::dynload::miopenGetTensorDescriptorSize(desc_5d, &nd);
EXPECT_EQ(nd, 5);
for (size_t i = 0; i < dims_5d.size(); ++i) {
EXPECT_EQ(dims_5d[i], shape_5d[i]);
}
EXPECT_EQ(strides_5d[4], 1);
EXPECT_EQ(strides_5d[3], 6);
EXPECT_EQ(strides_5d[2], 36);
EXPECT_EQ(strides_5d[1], 216);
EXPECT_EQ(strides_5d[0], 864);
}
TEST(MIOpenHelper, ScopedConvolutionDescriptor) {
using paddle::platform::ScopedConvolutionDescriptor;
ScopedConvolutionDescriptor conv_desc;
std::vector<int> src_pads = {2, 2, 2};
std::vector<int> src_strides = {1, 1, 1};
std::vector<int> src_dilations = {1, 1, 1};
auto desc = conv_desc.descriptor<float>(src_pads, src_strides, src_dilations);
miopenConvolutionMode_t mode;
int nd;
std::vector<int> pads(3);
std::vector<int> strides(3);
std::vector<int> dilations(3);
paddle::platform::dynload::miopenGetConvolutionNdDescriptor(
desc, 3, &nd, pads.data(), strides.data(), dilations.data(), &mode);
EXPECT_EQ(nd, 3);
for (size_t i = 0; i < src_pads.size(); ++i) {
EXPECT_EQ(pads[i], src_pads[i]);
EXPECT_EQ(strides[i], src_strides[i]);
EXPECT_EQ(dilations[i], src_dilations[i]);
}
EXPECT_EQ(mode, miopenConvolution);
}