提交 79aa5122 编写于 作者: C chengduoZH

fix conv, pool, conv_trans to decide use cudnn or not

上级 78dc9343
......@@ -70,6 +70,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::dynload::HasCUDNN();
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
......@@ -283,6 +284,7 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::dynload::HasCUDNN();
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/vol2col.h"
#include "paddle/platform/dynload/cudnn.h"
namespace paddle {
namespace operators {
......
......@@ -61,6 +61,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::dynload::HasCUDNN();
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
......@@ -263,6 +264,7 @@ void ConvTransposeOpGrad::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::dynload::HasCUDNN();
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/vol2col.h"
#include "paddle/platform/dynload/cudnn.h"
namespace paddle {
namespace operators {
......
......@@ -64,6 +64,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
framework::OpKernelType PoolOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::dynload::HasCUDNN();
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
......@@ -88,6 +89,7 @@ void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::dynload::HasCUDNN();
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/pooling.h"
#include "paddle/platform/dynload/cudnn.h"
namespace paddle {
namespace operators {
......
......@@ -57,6 +57,10 @@ void EnforceCUDNNLoaded(const char* fn_name) {
bool HasCUDNN() { return true; }
#endif
#ifndef PADDLE_WITH_CUDA
bool HasCUDNN() { return false; }
#endif
} // namespace dynload
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册