提交 79aa5122 编写于 作者: C chengduoZH

fix conv, pool, conv_trans to decide use cudnn or not

上级 78dc9343
...@@ -70,6 +70,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -70,6 +70,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOp::GetExpectedKernelType( framework::OpKernelType ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn"); bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::dynload::HasCUDNN();
framework::LibraryType library_; framework::LibraryType library_;
if (use_cudnn) { if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
...@@ -283,6 +284,7 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const { ...@@ -283,6 +284,7 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOpGrad::GetExpectedKernelType( framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn"); bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::dynload::HasCUDNN();
framework::LibraryType library_; framework::LibraryType library_;
if (use_cudnn) { if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/operators/math/im2col.h" #include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/vol2col.h" #include "paddle/operators/math/vol2col.h"
#include "paddle/platform/dynload/cudnn.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -61,6 +61,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -61,6 +61,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn"); bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::dynload::HasCUDNN();
framework::LibraryType library_; framework::LibraryType library_;
if (use_cudnn) { if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
...@@ -263,6 +264,7 @@ void ConvTransposeOpGrad::InferShape(framework::InferShapeContext* ctx) const { ...@@ -263,6 +264,7 @@ void ConvTransposeOpGrad::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType( framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn"); bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::dynload::HasCUDNN();
framework::LibraryType library_; framework::LibraryType library_;
if (use_cudnn) { if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/operators/math/im2col.h" #include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/vol2col.h" #include "paddle/operators/math/vol2col.h"
#include "paddle/platform/dynload/cudnn.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -64,6 +64,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -64,6 +64,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
framework::OpKernelType PoolOp::GetExpectedKernelType( framework::OpKernelType PoolOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const { const framework::ExecutionContext &ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn"); bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::dynload::HasCUDNN();
framework::LibraryType library_; framework::LibraryType library_;
if (use_cudnn) { if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
...@@ -88,6 +89,7 @@ void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const { ...@@ -88,6 +89,7 @@ void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
framework::OpKernelType PoolOpGrad::GetExpectedKernelType( framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const { const framework::ExecutionContext &ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn"); bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::dynload::HasCUDNN();
framework::LibraryType library_; framework::LibraryType library_;
if (use_cudnn) { if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/pooling.h" #include "paddle/operators/math/pooling.h"
#include "paddle/platform/dynload/cudnn.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -57,6 +57,10 @@ void EnforceCUDNNLoaded(const char* fn_name) { ...@@ -57,6 +57,10 @@ void EnforceCUDNNLoaded(const char* fn_name) {
bool HasCUDNN() { return true; } bool HasCUDNN() { return true; }
#endif #endif
#ifndef PADDLE_WITH_CUDA
bool HasCUDNN() { return false; }
#endif
} // namespace dynload } // namespace dynload
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册