提交 a77fcef3 编写于 作者: Q qijun

fix cuda compile error

上级 a30754b0
......@@ -3,7 +3,6 @@
namespace paddle {
namespace dyload {
namespace dynload {
std::once_flag cublas_dso_flag;
void *cublas_dso_handle = nullptr;
......@@ -67,8 +66,6 @@ CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP)
#undef DYNAMIC_LOAD_CUBLAS_V2_WRAP
#undef CUBLAS_BLAS_ROUTINE_EACH
} /* namespace dynload */
// clang-format on
#ifndef PADDLE_TYPE_DOUBLE
#define CUBLAS_GEAM dynload::cublasSgeam
......
......@@ -33,6 +33,15 @@ int GetDeviceCount(void) {
throw_on_error(cudaGetDeviceCount(&count), "cudaGetDeviceCount failed");
return count;
}
int GetCurrentDeviceId(void) {
int device_id;
throw_on_error(cudaGetDevice(&device_id), "cudaGetDevice failed");
return device_id;
}
void SetDeviceId(int device_id) {
throw_on_error(cudaSetDevice(device_id), "cudaSetDevice failed");
}
} // namespace platform
} // namespace paddle
......
......@@ -3,6 +3,8 @@
namespace paddle {
namespace dyload {
std::once_flag curand_dso_flag;
void *curand_dso_handle = nullptr;
#ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CURAND_WRAP(__name) \
struct DynLoad__##__name { \
......@@ -31,7 +33,8 @@ namespace dyload {
__macro(curandSetStream) \
__macro(curandSetPseudoRandomGeneratorSeed)\
__macro(curandGenerateUniform) \
__macro(curandGenerateUniformDouble)
__macro(curandGenerateUniformDouble) \
__macro(curandDestroyGenerator)
// clang-format on
CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP)
......
......@@ -83,10 +83,11 @@ class CudaDeviceContext : public DeviceContext {
cublasHandle_t cublas_handle() {
if (!blas_handle_) {
DeviceGuard guard(gpu_place_);
PADDLE_ENFORCE(cublasCreate(&blas_handle_) == CUBLAS_STATUS_SUCCESS,
"cublasCreate failed");
PADDLE_ENFORCE(
cublasSetStream(blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS,
paddle::dyload::cublasCreate(&blas_handle_) == CUBLAS_STATUS_SUCCESS,
"cublasCreate failed");
PADDLE_ENFORCE(paddle::dyload::cublasSetStream(blas_handle_, stream_) ==
CUBLAS_STATUS_SUCCESS,
"cublasSetStream failed");
}
return blas_handle_;
......@@ -95,10 +96,11 @@ class CudaDeviceContext : public DeviceContext {
cudnnHandle_t cudnn_handle() {
if (!dnn_handle_) {
DeviceGuard guard(gpu_place_);
PADDLE_ENFORCE(cudnnCreate(&dnn_handle_) == CUDNN_STATUS_SUCCESS,
"cudnnCreate failed");
PADDLE_ENFORCE(
cudnnSetStream(dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS,
paddle::dyload::cudnnCreate(&dnn_handle_) == CUDNN_STATUS_SUCCESS,
"cudnnCreate failed");
PADDLE_ENFORCE(paddle::dyload::cudnnSetStream(dnn_handle_, stream_) ==
CUDNN_STATUS_SUCCESS,
"cudnnSetStream failed");
}
return dnn_handle_;
......@@ -107,16 +109,16 @@ class CudaDeviceContext : public DeviceContext {
curandGenerator_t curand_generator() {
if (!rand_generator_) {
DeviceGuard guard(gpu_place_);
PADDLE_ENFORCE(
curandCreateGenerator(&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
PADDLE_ENFORCE(paddle::dyload::curandCreateGenerator(
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
CURAND_STATUS_SUCCESS,
"curandCreateGenerator failed");
PADDLE_ENFORCE(
curandSetPseudoRandomGeneratorSeed(rand_generator_, random_seed_) ==
CURAND_STATUS_SUCCESS,
paddle::dyload::curandSetPseudoRandomGeneratorSeed(
rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS,
"curandSetPseudoRandomGeneratorSeed failed");
PADDLE_ENFORCE(
curandSetStream(rand_generator_, stream_) == CURAND_STATUS_SUCCESS,
PADDLE_ENFORCE(paddle::dyload::curandSetStream(
rand_generator_, stream_) == CURAND_STATUS_SUCCESS,
"curandSetStream failed");
}
return rand_generator_;
......@@ -125,18 +127,20 @@ class CudaDeviceContext : public DeviceContext {
~CudaDeviceContext() {
Wait();
if (blas_handle_) {
PADDLE_ENFORCE(cublasDestroy(blas_handle_) == CUBLAS_STATUS_SUCCESS,
PADDLE_ENFORCE(
paddle::dyload::cublasDestroy(blas_handle_) == CUBLAS_STATUS_SUCCESS,
"cublasDestroy failed");
}
if (dnn_handle_) {
PADDLE_ENFORCE(cudnnDestroy(dnn_handle_) == CUDNN_STATUS_SUCCESS,
PADDLE_ENFORCE(
paddle::dyload::cudnnDestroy(dnn_handle_) == CUDNN_STATUS_SUCCESS,
"cudnnDestroy failed");
}
if (rand_generator_) {
PADDLE_ENFORCE(
curandDestroyGenerator(rand_generator_) == CURAND_STATUS_SUCCESS,
PADDLE_ENFORCE(paddle::dyload::curandDestroyGenerator(rand_generator_) ==
CURAND_STATUS_SUCCESS,
"curandDestroyGenerator failed");
}
......
......@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "dynamic_loader.h"
#include <gflags/gflags.h>
#include "DynamicLoader.h"
#include "Logging.h"
#include <glog/logging.h>
DEFINE_string(cudnn_dir, "",
"Specify path for loading libcudnn.so. For instance, "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册