From 24f528a1a56e099fc9ebad146614855e3481531e Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 17 Jan 2018 10:47:20 +0800 Subject: [PATCH] follow comments --- paddle/operators/conv_op.cc | 4 +--- paddle/operators/conv_transpose_op.cc | 4 +--- paddle/operators/pool_op.cc | 4 +--- python/paddle/v2/fluid/nets.py | 7 +++---- 4 files changed, 6 insertions(+), 13 deletions(-) diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index 9ae3e87281..63e4018555 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -70,9 +70,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType ConvOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { bool use_cudnn = ctx.Attr("use_cudnn"); - if (paddle::platform::is_cpu_place(ctx.GetPlace())) { - use_cudnn = false; - } + use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); framework::LibraryType library_; if (use_cudnn) { library_ = framework::LibraryType::kCUDNN; diff --git a/paddle/operators/conv_transpose_op.cc b/paddle/operators/conv_transpose_op.cc index 46f79b1701..4145638c74 100644 --- a/paddle/operators/conv_transpose_op.cc +++ b/paddle/operators/conv_transpose_op.cc @@ -61,9 +61,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { bool use_cudnn = ctx.Attr("use_cudnn"); - if (paddle::platform::is_cpu_place(ctx.GetPlace())) { - use_cudnn = false; - } + use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); framework::LibraryType library_; if (use_cudnn) { library_ = framework::LibraryType::kCUDNN; diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc index 648a1dfb56..ebe7d9a0a5 100644 --- a/paddle/operators/pool_op.cc +++ b/paddle/operators/pool_op.cc @@ -64,9 +64,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { framework::OpKernelType PoolOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { bool use_cudnn = ctx.Attr("use_cudnn"); - if (paddle::platform::is_cpu_place(ctx.GetPlace())) { - use_cudnn = false; - } + use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); framework::LibraryType library_; if (use_cudnn) { library_ = framework::LibraryType::kCUDNN; diff --git a/python/paddle/v2/fluid/nets.py b/python/paddle/v2/fluid/nets.py index 327d2ff2da..440467e0ab 100644 --- a/python/paddle/v2/fluid/nets.py +++ b/python/paddle/v2/fluid/nets.py @@ -41,10 +41,9 @@ def img_conv_group(input, param_attr=None, conv_with_batchnorm=False, conv_batchnorm_drop_rate=None, - conv_use_cudnn=True, pool_stride=1, pool_type=None, - pool_use_cudnn=True): + use_cudnn=True): """ Image Convolution Group, Used for vgg net. """ @@ -76,7 +75,7 @@ def img_conv_group(input, padding=conv_padding[i], param_attr=param_attr[i], act=local_conv_act, - use_cudnn=conv_use_cudnn) + use_cudnn=use_cudnn) if conv_with_batchnorm[i]: tmp = layers.batch_norm(input=tmp, act=conv_act) @@ -89,7 +88,7 @@ def img_conv_group(input, pool_size=pool_size, pool_type=pool_type, pool_stride=pool_stride, - use_cudnn=pool_use_cudnn) + use_cudnn=use_cudnn) return pool_out -- GitLab