gpu_info.h 3.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
L
liaogang 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#ifdef PADDLE_WITH_CUDA
L
liaogang 已提交
18

L
liaogang 已提交
19
#include <cuda_runtime.h>
L
liaogang 已提交
20
#include <stddef.h>
L
liaogang 已提交
21
#include <string>
22
#include <vector>
L
liaogang 已提交
23 24 25 26 27

namespace paddle {
namespace platform {

//! Get the total number of GPU devices in the system.
//! @return the number of GPU devices.
//! NOTE(review): whether devices hidden from this process (e.g. via
//! CUDA_VISIBLE_DEVICES) are counted is not visible in this header — confirm
//! against the definition.
int GetCUDADeviceCount();
L
liaogang 已提交
29

30 31 32
//! Get the compute capability of the ith GPU, encoded as major * 10 + minor
//! (e.g. compute capability 7.5 is returned as 75).
//! @param i  index of the GPU device to query.
//! @return the encoded compute capability.
int GetCUDAComputeCapability(int i);

C
chengduo 已提交
33 34 35 36 37 38
//! Get the CUDA runtime version for the GPU whose device id is `id`.
//! @param id  id of the GPU device to query.
//! @return the runtime version as an integer.
//! NOTE(review): the integer encoding of the version is not visible in this
//! header — confirm against the definition.
int GetCUDARuntimeVersion(int id);

//! Get the CUDA driver version for the GPU whose device id is `id`.
//! @param id  id of the GPU device to query.
//! @return the driver version as an integer.
//! NOTE(review): the integer encoding of the version is not visible in this
//! header — confirm against the definition.
int GetCUDADriverVersion(int id);

39 40 41
//! Whether the current device supports TensorCore.
//! @return true if TensorCore is available on the current device, false
//!         otherwise.
bool TensorCoreAvailable();

C
chengduoZH 已提交
42 43 44 45 46 47
//! Get the number of multiprocessors on the ith GPU.
//! @param i  index of the GPU device to query.
//! @return the multiprocessor count.
int GetCUDAMultiProcessors(int i);

//! Get the maximum number of threads per multiprocessor on the ith GPU.
//! @param i  index of the GPU device to query.
//! @return the maximum thread count of each multiprocessor.
int GetCUDAMaxThreadsPerMultiProcessor(int i);

48 49 50
//! Get the maximum number of threads per block on the ith GPU.
//! @param i  index of the GPU device to query.
//! @return the maximum thread count of each block.
int GetCUDAMaxThreadsPerBlock(int i);

L
liaogang 已提交
51 52 53
//! Get the id of the current GPU device.
//! @return the current GPU device id.
int GetCurrentDeviceId();

54 55 56
//! Get the maximum grid dimensions of a GPU device, one limit per axis in the
//! x, y and z members of the returned dim3.
//! @param id  presumably the GPU device id, as in the other per-device
//!            queries in this header — TODO confirm against the definition.
//! @return the maximum grid size per dimension.
dim3 GetGpuMaxGridDimSize(int id);

57 58 59
//! Get the list of GPU device ids to use. The ids come from an environment
//! variable when it is set; otherwise all devices are used.
//! @return the selected device ids.
//! NOTE(review): the name and format of the environment variable are not
//! visible in this header — confirm against the definition.
std::vector<int> GetSelectedDevices();

L
liaogang 已提交
60 61 62
//! Set the GPU device used for subsequent execution.
//! @param device_id  id of the GPU device to use.
void SetDeviceId(int device_id);

Q
Qiao Longfei 已提交
63
//! Get the memory usage of the current GPU device.
//! @param[out] available  receives the amount of available memory.
//! @param[out] total      receives the total amount of memory.
//! NOTE(review): units are presumably bytes — confirm against the definition.
void GpuMemoryUsage(size_t *available, size_t *total);
L
liaogang 已提交
65

66 67 68 69
//! Get the amount of memory that can be allocated: the available GPU memory
//! minus the reserved portion.
//! NOTE(review): the size of the reserved portion and the units (presumably
//! bytes) are not visible in this header — confirm against the definition.
size_t GpuAvailableMemToAlloc();

L
liaogang 已提交
70 71 72
//! Get the maximum allocation size of the current GPU device.
//! NOTE(review): units are presumably bytes — confirm against the definition.
size_t GpuMaxAllocSize();

Z
zhhsplendid 已提交
73 74 75 76 77 78
//! Get the initial allocation size of the current GPU device.
//! NOTE(review): units are presumably bytes — confirm against the definition.
size_t GpuInitAllocSize();

//! Get the re-allocation size of the current GPU device.
//! NOTE(review): presumably the size requested for allocations made after the
//! initial one, in bytes — confirm against the definition.
size_t GpuReallocSize();

L
liaogang 已提交
79 80 81 82 83 84
//! Get the minimum chunk size used by the GPU buddy allocator.
size_t GpuMinChunkSize();

//! Get the maximum chunk size used by the GPU buddy allocator.
size_t GpuMaxChunkSize();

L
liaogang 已提交
85 86 87 88
//! Copy memory from address `src` to `dst` asynchronously on `stream`.
//! @param dst     destination address.
//! @param src     source address.
//! @param count   size of the copy (presumably in bytes, matching the
//!                cudaMemcpyAsync parameter of the same name — confirm).
//! @param kind    direction of the copy, as a cudaMemcpyKind.
//! @param stream  CUDA stream on which the copy is issued.
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                    enum cudaMemcpyKind kind, cudaStream_t stream);

89 90 91 92 93 94 95 96 97 98 99
//! Copy memory from address `src` to `dst` synchronously.
//! @param dst    destination address.
//! @param src    source address.
//! @param count  size of the copy (presumably in bytes — confirm against the
//!               definition).
//! @param kind   direction of the copy, as a cudaMemcpyKind.
void GpuMemcpySync(void *dst, const void *src, size_t count,
                   enum cudaMemcpyKind kind);

//! Copy memory from one device to another device asynchronously on `stream`.
//! @param dst         destination address on device `dst_device`.
//! @param dst_device  id of the destination device.
//! @param src         source address on device `src_device`.
//! @param src_device  id of the source device.
//! @param count       size of the copy (presumably in bytes — confirm against
//!                    the definition).
//! @param stream      CUDA stream on which the copy is issued.
void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
                        int src_device, size_t count, cudaStream_t stream);

//! Copy memory from one device to another device synchronously.
//! @param dst         destination address on device `dst_device`.
//! @param dst_device  id of the destination device.
//! @param src         source address on device `src_device`.
//! @param src_device  id of the source device.
//! @param count       size of the copy (presumably in bytes — confirm against
//!                    the definition).
void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src,
                       int src_device, size_t count);
L
liaogang 已提交
100

D
dzhwinter 已提交
101 102 103
//! Fill the memory starting at `dst` with `value` asynchronously on `stream`.
//! @param dst     destination address.
//! @param value   fill value.
//! @param count   size of the region to fill (presumably in bytes — confirm).
//! @param stream  CUDA stream on which the fill is issued.
//! NOTE(review): if this forwards to cudaMemsetAsync, the fill is byte-wise
//! (only the low byte of `value` is used). It then cannot, for example, fill
//! a float buffer with 1.0f. Confirm against the definition.
void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream);

石晓伟 已提交
104 105 106
//! Block until `stream` has completed all operations issued to it.
//! @param stream  CUDA stream to wait on.
void GpuStreamSync(cudaStream_t stream);

107 108 109
//! Raise an error if `*status` is neither cudaSuccess nor an out-of-memory
//! error; otherwise reset `*status`.
//! @param status  pointer to the CUDA status to inspect and possibly reset.
//! NOTE(review): this header does not show which cudaError_t value counts as
//! out-of-memory, what `*status` is reset to, or how the error is raised
//! (exception vs. abort). Confirm against the definition.
void RaiseNonOutOfMemoryError(cudaError_t *status);

L
liaogang 已提交
110 111 112
}  // namespace platform
}  // namespace paddle

L
Luo Tao 已提交
113
#endif