From c9641a03dc232862f9f8015f39fc11eb30d81693 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 17 Jan 2018 15:18:49 +0800 Subject: [PATCH] refine code --- paddle/operators/conv_op.cc | 13 +++++++++++++ paddle/operators/conv_transpose_op.cc | 12 ++++++++++++ paddle/operators/pool_op.cc | 12 ++++++++++++ 3 files changed, 37 insertions(+) diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index 0e8dddd7f19..d6882b275b2 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -71,6 +71,12 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { bool use_cudnn = ctx.Attr("use_cudnn"); use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(ctx.GetPlace())) { + auto& dev_ctx = ctx.template device_context(); + use_cudnn &= dev_ctx.cudnn_handle() != nullptr; + } +#endif framework::LibraryType library_; if (use_cudnn) { library_ = framework::LibraryType::kCUDNN; @@ -285,6 +291,13 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { bool use_cudnn = ctx.Attr("use_cudnn"); use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(ctx.GetPlace())) { + auto& dev_ctx = ctx.template device_context(); + use_cudnn &= dev_ctx.cudnn_handle() != nullptr; + } +#endif + framework::LibraryType library_; if (use_cudnn) { library_ = framework::LibraryType::kCUDNN; diff --git a/paddle/operators/conv_transpose_op.cc b/paddle/operators/conv_transpose_op.cc index f71838c2aa7..a2382a7e42e 100644 --- a/paddle/operators/conv_transpose_op.cc +++ b/paddle/operators/conv_transpose_op.cc @@ -62,6 +62,12 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { bool use_cudnn = ctx.Attr("use_cudnn"); use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(ctx.GetPlace())) { + auto& dev_ctx = ctx.template device_context(); + use_cudnn &= dev_ctx.cudnn_handle() != nullptr; + } +#endif framework::LibraryType library_; if (use_cudnn) { library_ = framework::LibraryType::kCUDNN; @@ -265,6 +271,12 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { bool use_cudnn = ctx.Attr("use_cudnn"); use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(ctx.GetPlace())) { + auto& dev_ctx = ctx.template device_context(); + use_cudnn &= dev_ctx.cudnn_handle() != nullptr; + } +#endif framework::LibraryType library_; if (use_cudnn) { library_ = framework::LibraryType::kCUDNN; diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc index a4502794519..b97333bb1a1 100644 --- a/paddle/operators/pool_op.cc +++ b/paddle/operators/pool_op.cc @@ -65,6 +65,12 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { bool use_cudnn = ctx.Attr("use_cudnn"); use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.template device_context(); + use_cudnn &= dev_ctx.cudnn_handle() != nullptr; + } +#endif framework::LibraryType library_; if (use_cudnn) { library_ = framework::LibraryType::kCUDNN; @@ -90,6 +96,12 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { bool use_cudnn = ctx.Attr("use_cudnn"); use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.template device_context(); + use_cudnn &= dev_ctx.cudnn_handle() != nullptr; + } +#endif framework::LibraryType library_; if (use_cudnn) { library_ = framework::LibraryType::kCUDNN; -- GitLab