提交 a77fcef3 编写于 作者: Q qijun

fix cuda compile error

上级 a30754b0
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
namespace paddle { namespace paddle {
namespace dyload { namespace dyload {
namespace dynload {
std::once_flag cublas_dso_flag; std::once_flag cublas_dso_flag;
void *cublas_dso_handle = nullptr; void *cublas_dso_handle = nullptr;
...@@ -67,8 +66,6 @@ CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP) ...@@ -67,8 +66,6 @@ CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP)
#undef DYNAMIC_LOAD_CUBLAS_V2_WRAP #undef DYNAMIC_LOAD_CUBLAS_V2_WRAP
#undef CUBLAS_BLAS_ROUTINE_EACH #undef CUBLAS_BLAS_ROUTINE_EACH
} /* namespace dynload */
// clang-format on // clang-format on
#ifndef PADDLE_TYPE_DOUBLE #ifndef PADDLE_TYPE_DOUBLE
#define CUBLAS_GEAM dynload::cublasSgeam #define CUBLAS_GEAM dynload::cublasSgeam
......
...@@ -33,6 +33,15 @@ int GetDeviceCount(void) { ...@@ -33,6 +33,15 @@ int GetDeviceCount(void) {
throw_on_error(cudaGetDeviceCount(&count), "cudaGetDeviceCount failed"); throw_on_error(cudaGetDeviceCount(&count), "cudaGetDeviceCount failed");
return count; return count;
} }
// Asks the CUDA runtime (cudaGetDevice) which device id is current and
// returns it. The runtime status is handed to throw_on_error together with
// the message "cudaGetDevice failed"; presumably that helper raises on a
// non-success status -- its definition is not visible here.
int GetCurrentDeviceId(void) {
  int current_device;
  auto status = cudaGetDevice(&current_device);
  throw_on_error(status, "cudaGetDevice failed");
  return current_device;
}
// Selects `device_id` through cudaSetDevice. The runtime status is handed to
// throw_on_error together with the message "cudaSetDevice failed"; presumably
// that helper raises on a non-success status -- its definition is not
// visible here.
void SetDeviceId(int device_id) {
  auto status = cudaSetDevice(device_id);
  throw_on_error(status, "cudaSetDevice failed");
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
namespace paddle { namespace paddle {
namespace dyload { namespace dyload {
// Once-flag paired with curand_dso_handle below.
// NOTE(review): presumably passed to std::call_once so the curand shared
// object is opened a single time -- the call site is not visible here.
std::once_flag curand_dso_flag;
// Opaque handle, initially null.
// NOTE(review): presumably filled in with the loaded curand library handle
// by the dynamic-load wrappers -- confirm against DYNAMIC_LOAD_CURAND_WRAP.
void *curand_dso_handle = nullptr;
#ifdef PADDLE_USE_DSO #ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CURAND_WRAP(__name) \ #define DYNAMIC_LOAD_CURAND_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
...@@ -31,7 +33,8 @@ namespace dyload { ...@@ -31,7 +33,8 @@ namespace dyload {
__macro(curandSetStream) \ __macro(curandSetStream) \
__macro(curandSetPseudoRandomGeneratorSeed)\ __macro(curandSetPseudoRandomGeneratorSeed)\
__macro(curandGenerateUniform) \ __macro(curandGenerateUniform) \
__macro(curandGenerateUniformDouble) __macro(curandGenerateUniformDouble) \
__macro(curandDestroyGenerator)
// clang-format on // clang-format on
CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP) CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP)
......
...@@ -83,11 +83,12 @@ class CudaDeviceContext : public DeviceContext { ...@@ -83,11 +83,12 @@ class CudaDeviceContext : public DeviceContext {
cublasHandle_t cublas_handle() { cublasHandle_t cublas_handle() {
if (!blas_handle_) { if (!blas_handle_) {
DeviceGuard guard(gpu_place_); DeviceGuard guard(gpu_place_);
PADDLE_ENFORCE(cublasCreate(&blas_handle_) == CUBLAS_STATUS_SUCCESS,
"cublasCreate failed");
PADDLE_ENFORCE( PADDLE_ENFORCE(
cublasSetStream(blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS, paddle::dyload::cublasCreate(&blas_handle_) == CUBLAS_STATUS_SUCCESS,
"cublasSetStream failed"); "cublasCreate failed");
PADDLE_ENFORCE(paddle::dyload::cublasSetStream(blas_handle_, stream_) ==
CUBLAS_STATUS_SUCCESS,
"cublasSetStream failed");
} }
return blas_handle_; return blas_handle_;
} }
...@@ -95,11 +96,12 @@ class CudaDeviceContext : public DeviceContext { ...@@ -95,11 +96,12 @@ class CudaDeviceContext : public DeviceContext {
cudnnHandle_t cudnn_handle() { cudnnHandle_t cudnn_handle() {
if (!dnn_handle_) { if (!dnn_handle_) {
DeviceGuard guard(gpu_place_); DeviceGuard guard(gpu_place_);
PADDLE_ENFORCE(cudnnCreate(&dnn_handle_) == CUDNN_STATUS_SUCCESS,
"cudnnCreate failed");
PADDLE_ENFORCE( PADDLE_ENFORCE(
cudnnSetStream(dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS, paddle::dyload::cudnnCreate(&dnn_handle_) == CUDNN_STATUS_SUCCESS,
"cudnnSetStream failed"); "cudnnCreate failed");
PADDLE_ENFORCE(paddle::dyload::cudnnSetStream(dnn_handle_, stream_) ==
CUDNN_STATUS_SUCCESS,
"cudnnSetStream failed");
} }
return dnn_handle_; return dnn_handle_;
} }
...@@ -107,17 +109,17 @@ class CudaDeviceContext : public DeviceContext { ...@@ -107,17 +109,17 @@ class CudaDeviceContext : public DeviceContext {
curandGenerator_t curand_generator() { curandGenerator_t curand_generator() {
if (!rand_generator_) { if (!rand_generator_) {
DeviceGuard guard(gpu_place_); DeviceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::dyload::curandCreateGenerator(
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
CURAND_STATUS_SUCCESS,
"curandCreateGenerator failed");
PADDLE_ENFORCE( PADDLE_ENFORCE(
curandCreateGenerator(&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) == paddle::dyload::curandSetPseudoRandomGeneratorSeed(
CURAND_STATUS_SUCCESS, rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS,
"curandCreateGenerator failed");
PADDLE_ENFORCE(
curandSetPseudoRandomGeneratorSeed(rand_generator_, random_seed_) ==
CURAND_STATUS_SUCCESS,
"curandSetPseudoRandomGeneratorSeed failed"); "curandSetPseudoRandomGeneratorSeed failed");
PADDLE_ENFORCE( PADDLE_ENFORCE(paddle::dyload::curandSetStream(
curandSetStream(rand_generator_, stream_) == CURAND_STATUS_SUCCESS, rand_generator_, stream_) == CURAND_STATUS_SUCCESS,
"curandSetStream failed"); "curandSetStream failed");
} }
return rand_generator_; return rand_generator_;
} }
...@@ -125,19 +127,21 @@ class CudaDeviceContext : public DeviceContext { ...@@ -125,19 +127,21 @@ class CudaDeviceContext : public DeviceContext {
~CudaDeviceContext() { ~CudaDeviceContext() {
Wait(); Wait();
if (blas_handle_) { if (blas_handle_) {
PADDLE_ENFORCE(cublasDestroy(blas_handle_) == CUBLAS_STATUS_SUCCESS, PADDLE_ENFORCE(
"cublasDestroy failed"); paddle::dyload::cublasDestroy(blas_handle_) == CUBLAS_STATUS_SUCCESS,
"cublasDestroy failed");
} }
if (dnn_handle_) { if (dnn_handle_) {
PADDLE_ENFORCE(cudnnDestroy(dnn_handle_) == CUDNN_STATUS_SUCCESS, PADDLE_ENFORCE(
"cudnnDestroy failed"); paddle::dyload::cudnnDestroy(dnn_handle_) == CUDNN_STATUS_SUCCESS,
"cudnnDestroy failed");
} }
if (rand_generator_) { if (rand_generator_) {
PADDLE_ENFORCE( PADDLE_ENFORCE(paddle::dyload::curandDestroyGenerator(rand_generator_) ==
curandDestroyGenerator(rand_generator_) == CURAND_STATUS_SUCCESS, CURAND_STATUS_SUCCESS,
"curandDestroyGenerator failed"); "curandDestroyGenerator failed");
} }
delete eigen_stream_; delete eigen_stream_;
......
...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "dynamic_loader.h"
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include "DynamicLoader.h" #include <glog/logging.h>
#include "Logging.h"
DEFINE_string(cudnn_dir, "", DEFINE_string(cudnn_dir, "",
"Specify path for loading libcudnn.so. For instance, " "Specify path for loading libcudnn.so. For instance, "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册