未验证 提交 89bbc5bc 编写于 作者: H Houjiang Chen 提交者: GitHub

Get device index from local rank if multi-client, otherwise use the current device. (#6405)

* Fix random generator

* Get device index from local rank if multi-client, otherwise use current device.
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 62737734
......@@ -70,7 +70,7 @@ Maybe<Generator> DefaultCUDAGenerator(int device_index) {
static std::vector<std::once_flag> init_flags(device_count);
static std::vector<std::shared_ptr<Generator>> default_cuda_generator(device_count);
if (device_index == -1) { device_index = GlobalProcessCtx::LocalRank(); }
if (device_index == -1) { device_index = detail::GetCudaDeviceIndex(); }
CHECK_OR_RETURN(device_index >= 0 && device_index < device_count)
<< "Invalid device index " << device_index;
std::call_once(init_flags[device_index], [&]() {
......@@ -91,7 +91,7 @@ Maybe<Generator> MakeCPUGenerator() {
#ifdef WITH_CUDA
Maybe<Generator> MakeCUDAGenerator(int device_index) {
if (device_index == -1) { device_index = GlobalProcessCtx::LocalRank(); }
if (device_index == -1) { device_index = detail::GetCudaDeviceIndex(); }
CHECK_OR_RETURN(device_index >= 0 && device_index < detail::GetCudaDeviceCount())
<< "Invalid device index " << device_index;
return std::make_shared<Generator>(
......
......@@ -145,10 +145,9 @@ int GetThreadNum(const cudaDeviceProp& prop) {
}
}
// Runs CPUSynchronize() and then cudaDeviceSynchronize(), returning Maybe<void>::Ok()
// once both calls have completed.
// NOTE(review): this span appears to interleave removed and added diff lines. The
// one-argument signature and the cudaSetDevice(device_index) call look like the
// pre-change version; the zero-argument signature looks like the post-change version,
// in which callers select the device themselves before calling. Confirm against the
// repository before treating this span as compilable code.
Maybe<void> CUDASynchronize(int device_index) {
Maybe<void> CUDASynchronize() {
// Synchronize the CUDA device so that generator state is not being modified by random kernels.
JUST(CPUSynchronize());
OF_CUDA_CHECK(cudaSetDevice(device_index));
OF_CUDA_CHECK(cudaDeviceSynchronize());
return Maybe<void>::Ok();
}
......@@ -161,25 +160,29 @@ CUDAGeneratorImpl::CUDAGeneratorImpl(uint64_t seed, int device_index)
OF_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_index));
max_block_num_ = prop.multiProcessorCount;
max_thread_num_ = GetThreadNum(prop);
OF_CUDA_CHECK(cudaSetDevice(device_index));
CudaCurrentDeviceGuard dev_guard(device_index);
OF_CUDA_CHECK(
cudaMalloc(&curand_states_, max_block_num_ * max_thread_num_ * sizeof(curandState)));
detail::InitCurandStates(seed, max_block_num_, max_thread_num_, curand_states_);
}
// Destructor: synchronizes via CUDASynchronize and then releases curand_states_ with
// cudaFree (the buffer is obtained with cudaMalloc in the constructor).
// NOTE(review): this span appears to interleave removed and added diff lines. The
// CUDASynchronize(this->device_index()) call looks like the pre-change version; the
// device guard followed by the zero-argument CUDASynchronize() looks like its
// replacement. Confirm against the repository.
CUDAGeneratorImpl::~CUDAGeneratorImpl() {
CHECK_JUST(CUDASynchronize(this->device_index()));
// Presumably makes this generator's device current for the rest of the scope and
// restores the previous device on exit; verify CudaCurrentDeviceGuard's semantics.
CudaCurrentDeviceGuard dev_guard(this->device_index());
CHECK_JUST(CUDASynchronize());
OF_CUDA_CHECK(cudaFree(curand_states_));
}
// Stores `seed` in seed_ and re-initializes curand_states_ from it via
// detail::InitCurandStates, after synchronizing through CUDASynchronize.
// NOTE(review): this span appears to interleave removed and added diff lines. The
// CUDASynchronize(this->device_index()) call looks like the pre-change version; the
// device guard followed by the zero-argument CUDASynchronize() looks like its
// replacement. Confirm against the repository.
void CUDAGeneratorImpl::set_current_seed(uint64_t seed) {
CHECK_JUST(CUDASynchronize(this->device_index()));
// Presumably keeps this generator's device current for the synchronize and the
// re-initialization below; verify CudaCurrentDeviceGuard's semantics.
CudaCurrentDeviceGuard dev_guard(this->device_index());
CHECK_JUST(CUDASynchronize());
seed_ = seed;
// Uses the same block/thread counts that sized curand_states_ in the constructor.
detail::InitCurandStates(seed_, max_block_num_, max_thread_num_, curand_states_);
}
Maybe<Tensor> CUDAGeneratorImpl::GetState() const {
JUST(CUDASynchronize(this->device_index()));
CudaCurrentDeviceGuard dev_guard(this->device_index());
JUST(CUDASynchronize());
int64_t state_size = max_block_num_ * max_thread_num_ * sizeof(curandState);
int64_t total_size = state_size + sizeof(int64_t);
const auto& device = JUST(Device::New("cpu"));
......@@ -207,7 +210,8 @@ Maybe<void> CUDAGeneratorImpl::SetState(const std::shared_ptr<Tensor>& tensor_st
<< total_size << ", but got " << tensor_state->shape()->elem_cnt();
}
JUST(CUDASynchronize(this->device_index()));
CudaCurrentDeviceGuard dev_guard(this->device_index());
JUST(CUDASynchronize());
const auto& callback = std::make_shared<std::function<void(uint64_t)>>([&](uint64_t of_blob_ptr) {
auto* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
const int8_t* data = of_blob->blob().dptr<int8_t>();
......@@ -398,16 +402,27 @@ Maybe<CPUGeneratorImpl> MakeGeneratorImpl<CPUGeneratorImpl>(uint64_t seed, int d
}
#ifdef WITH_CUDA
// Picks the CUDA device index to use when the caller did not specify one:
// the process-local rank when GlobalMultiClientEnv() reports true, otherwise
// whatever device cudaGetDevice reports as current.
int GetCudaDeviceIndex() {
  if (CHECK_JUST(GlobalMultiClientEnv())) { return GlobalProcessCtx::LocalRank(); }
  int current_device = 0;
  OF_CUDA_CHECK(cudaGetDevice(&current_device));
  return current_device;
}
// Returns the device count written by cudaGetDeviceCount.
// NOTE(review): this span appears to interleave removed and added diff lines. The
// uninitialized declaration and the cudaSetDevice(GlobalProcessCtx::LocalRank()) call
// look like the pre-change version; the zero-initialized declaration and the device
// guard look like the post-change version. As shown, cuda_device_count is declared
// twice, so confirm against the repository before treating this span as compilable code.
int GetCudaDeviceCount() {
/* static */ int cuda_device_count;
OF_CUDA_CHECK(cudaSetDevice(GlobalProcessCtx::LocalRank()));
/* static */ int cuda_device_count = 0;
// The device index comes from detail::GetCudaDeviceIndex() rather than directly from
// the local rank.
CudaCurrentDeviceGuard dev_guard(detail::GetCudaDeviceIndex());
OF_CUDA_CHECK(cudaGetDeviceCount(&cuda_device_count));
return cuda_device_count;
}
template<>
DeviceKey MakeDeviceKey<CUDAGeneratorImpl>(int device_index) {
if (device_index == -1) { device_index = GlobalProcessCtx::LocalRank(); }
if (device_index == -1) { device_index = detail::GetCudaDeviceIndex(); }
DeviceKey device_key;
device_key.device_type = DeviceType::kGPU;
device_key.device_index = device_index;
......
......@@ -137,6 +137,7 @@ class CUDAGeneratorImpl : public DeviceGeneratorImpl {
namespace detail {
int GetCudaDeviceIndex();
int GetCudaDeviceCount();
void InitCurandStates(uint64_t seed, int32_t block_num, int32_t thread_num, curandState* states);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册