From 79aa51229ad41c4633056ad0f7056bbb88648a87 Mon Sep 17 00:00:00 2001
From: chengduoZH
Date: Mon, 15 Jan 2018 19:46:03 +0800
Subject: [PATCH] fix conv, pool, conv_trans to decide use cudnn or not

---
 paddle/operators/conv_op.cc           | 2 ++
 paddle/operators/conv_op.h            | 1 +
 paddle/operators/conv_transpose_op.cc | 2 ++
 paddle/operators/conv_transpose_op.h  | 1 +
 paddle/operators/pool_op.cc           | 2 ++
 paddle/operators/pool_op.h            | 1 +
 paddle/platform/dynload/cudnn.cc      | 4 ++++
 7 files changed, 13 insertions(+)

diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc
index 424eccdb7dc..a6d3ebb4f52 100644
--- a/paddle/operators/conv_op.cc
+++ b/paddle/operators/conv_op.cc
@@ -70,6 +70,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
 framework::OpKernelType ConvOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+  use_cudnn &= platform::dynload::HasCUDNN();
   framework::LibraryType library_;
   if (use_cudnn) {
     library_ = framework::LibraryType::kCUDNN;
@@ -283,6 +284,7 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
 framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+  use_cudnn &= platform::dynload::HasCUDNN();
   framework::LibraryType library_;
   if (use_cudnn) {
     library_ = framework::LibraryType::kCUDNN;
diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h
index 5a8933e7915..2a6fad65826 100644
--- a/paddle/operators/conv_op.h
+++ b/paddle/operators/conv_op.h
@@ -19,6 +19,7 @@ limitations under the License. */
 #include "paddle/operators/math/im2col.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/vol2col.h"
+#include "paddle/platform/dynload/cudnn.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/operators/conv_transpose_op.cc b/paddle/operators/conv_transpose_op.cc
index cf4e8c0a303..53153c02963 100644
--- a/paddle/operators/conv_transpose_op.cc
+++ b/paddle/operators/conv_transpose_op.cc
@@ -61,6 +61,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
 framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+  use_cudnn &= platform::dynload::HasCUDNN();
   framework::LibraryType library_;
   if (use_cudnn) {
     library_ = framework::LibraryType::kCUDNN;
@@ -263,6 +264,7 @@ void ConvTransposeOpGrad::InferShape(framework::InferShapeContext* ctx) const {
 framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+  use_cudnn &= platform::dynload::HasCUDNN();
   framework::LibraryType library_;
   if (use_cudnn) {
     library_ = framework::LibraryType::kCUDNN;
diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h
index a42ade41b16..f43d652fe1f 100644
--- a/paddle/operators/conv_transpose_op.h
+++ b/paddle/operators/conv_transpose_op.h
@@ -19,6 +19,7 @@ limitations under the License. */
 #include "paddle/operators/math/im2col.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/vol2col.h"
+#include "paddle/platform/dynload/cudnn.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc
index 3e567efd082..e0f9ea7aa0d 100644
--- a/paddle/operators/pool_op.cc
+++ b/paddle/operators/pool_op.cc
@@ -64,6 +64,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
 framework::OpKernelType PoolOp::GetExpectedKernelType(
     const framework::ExecutionContext &ctx) const {
   bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+  use_cudnn &= platform::dynload::HasCUDNN();
   framework::LibraryType library_;
   if (use_cudnn) {
     library_ = framework::LibraryType::kCUDNN;
@@ -88,6 +89,7 @@ void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
 framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext &ctx) const {
   bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+  use_cudnn &= platform::dynload::HasCUDNN();
   framework::LibraryType library_;
   if (use_cudnn) {
     library_ = framework::LibraryType::kCUDNN;
diff --git a/paddle/operators/pool_op.h b/paddle/operators/pool_op.h
index c3d82ecbdeb..e0668f1360d 100644
--- a/paddle/operators/pool_op.h
+++ b/paddle/operators/pool_op.h
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/pooling.h"
+#include "paddle/platform/dynload/cudnn.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/platform/dynload/cudnn.cc b/paddle/platform/dynload/cudnn.cc
index 701f6240fef..fafdf5e78a9 100644
--- a/paddle/platform/dynload/cudnn.cc
+++ b/paddle/platform/dynload/cudnn.cc
@@ -57,6 +57,10 @@ void EnforceCUDNNLoaded(const char* fn_name) {
 bool HasCUDNN() { return true; }
 #endif
 
+#ifndef PADDLE_WITH_CUDA
+bool HasCUDNN() { return false; }
+#endif
+
 }  // namespace dynload
 }  // namespace platform
 }  // namespace paddle
--
GitLab
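For context only, not part of the patch above: a minimal, self-contained C++ sketch of the guard pattern the diff adds to every GetExpectedKernelType, where the cuDNN kernel is selected only when the use_cudnn attribute is set and cuDNN is actually available at runtime. The names SelectLibrary, LibraryType, and the stub HasCUDNN below are hypothetical stand-ins, not PaddlePaddle APIs.

// Hypothetical stand-alone sketch; only the "use_cudnn &= HasCUDNN()" guard
// mirrors the patch, everything else is a simplified stand-in.
#include <iostream>

enum class LibraryType { kPlain, kCUDNN };

// Stand-in for platform::dynload::HasCUDNN(); after the patch it returns
// false in builds without PADDLE_WITH_CUDA, so the guard below disables
// the cuDNN path automatically.
bool HasCUDNN() { return false; }

LibraryType SelectLibrary(bool use_cudnn_attr) {
  bool use_cudnn = use_cudnn_attr;
  use_cudnn &= HasCUDNN();  // the single line the patch adds in each operator
  return use_cudnn ? LibraryType::kCUDNN : LibraryType::kPlain;
}

int main() {
  // Even though the attribute asks for cuDNN, the guard falls back to the
  // plain kernel because HasCUDNN() reports that cuDNN is unavailable.
  std::cout << (SelectLibrary(true) == LibraryType::kCUDNN ? "cuDNN" : "plain")
            << std::endl;
  return 0;
}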