提交 f214e146 编写于 作者: M Megvii Engine Team

refactor(mgb/cuda): use single implementation of get_device_prop from utils

GitOrigin-RevId: 5cc95472b9f27339380f74a4a2828368af56c038
上级 54e79dd1
......@@ -46,7 +46,7 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle):
cuda_check(cudaGetDevice(&dev_id));
}
m_device_id = dev_id;
cuda_check(cudaGetDeviceProperties(&m_device_prop, dev_id));
m_device_prop = get_device_prop(dev_id);
// Get stream from MegCore computing handle.
megdnn_assert(CUDNN_VERSION == cudnnGetVersion(),
"cudnn version mismatch: compiled with %d; detected %zu at runtime",
......@@ -80,7 +80,7 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle):
cuda_check(cudaStreamSynchronize(stream()));
// check tk1
m_is_tegra_k1 = (strcmp(m_device_prop.name, "GK20A") == 0);
m_is_tegra_k1 = (strcmp(m_device_prop->name, "GK20A") == 0);
m_cusolver_handle = nullptr;
}
......@@ -104,7 +104,7 @@ void HandleImpl::ConstScalars::init() {
size_t HandleImpl::alignment_requirement() const {
auto &&prop = m_device_prop;
return std::max(prop.textureAlignment, prop.texturePitchAlignment);
return std::max(prop->textureAlignment, prop->texturePitchAlignment);
}
bool HandleImpl::check_cross_dev_copy_constraint(const TensorLayout& src) {
......
......@@ -42,7 +42,7 @@ class HandleImpl: public HandleImplHelper {
bool check_cross_dev_copy_constraint(const TensorLayout &src) override;
const cudaDeviceProp& device_prop() const {
return m_device_prop;
return *m_device_prop;
}
template <typename Opr>
......@@ -137,7 +137,7 @@ class HandleImpl: public HandleImplHelper {
cusolverDnHandle_t m_cusolver_handle;
std::once_flag m_cusolver_initialized;
cudaDeviceProp m_device_prop;
const cudaDeviceProp* m_device_prop;
struct ConstScalars {
union FP16 {
......
......@@ -107,19 +107,26 @@ uint32_t cuda::safe_size_in_kern(size_t size) {
return size;
}
cudaDeviceProp cuda::current_device_prop() {
const cudaDeviceProp& cuda::current_device_prop() {
int dev;
cuda_check(cudaGetDevice(&dev));
megdnn_assert(dev < MAX_NR_DEVICE, "device number too large: %d", dev);
auto&& rec = device_prop_rec[dev];
return *(cuda::get_device_prop(dev));
}
const cudaDeviceProp* cuda::get_device_prop(int device) {
megdnn_assert(device < MAX_NR_DEVICE, "device number too large: %d",
device);
megdnn_assert(device >= 0, "device number must not be negative, got %d",
device);
auto&& rec = device_prop_rec[device];
if (!rec.init) {
std::lock_guard<std::mutex> lock(rec.mtx);
if (!rec.init) {
cuda_check(cudaGetDeviceProperties(&rec.prop, dev));
cuda_check(cudaGetDeviceProperties(&rec.prop, device));
rec.init = true;
}
}
return rec.prop;
return &(rec.prop);
}
bool cuda::is_compute_capability_required(int major, int minor) {
......
......@@ -52,7 +52,10 @@ static inline void CUDART_CB callback_free(cudaStream_t /* stream */,
}
//! get property of currently active device
cudaDeviceProp current_device_prop();
const cudaDeviceProp& current_device_prop();
//! get property of device specified by device
const cudaDeviceProp* get_device_prop(int device);
//! check compute capability satisfied with given sm version
bool is_compute_capability_required(int major, int minor);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册