diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index 9ae3e872812adaf00d6ead3f9fe32ecd9bef006e..63e4018555012ca5082ebbeec80c1fb5f4ebc97d 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -70,9 +70,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType ConvOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { bool use_cudnn = ctx.Attr("use_cudnn"); - if (paddle::platform::is_cpu_place(ctx.GetPlace())) { - use_cudnn = false; - } + use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); framework::LibraryType library_; if (use_cudnn) { library_ = framework::LibraryType::kCUDNN; diff --git a/paddle/operators/conv_transpose_op.cc b/paddle/operators/conv_transpose_op.cc index 46f79b1701f0f4f1fad83bff0abc422416099373..4145638c74ecae11cd30d1cf57a6d76c0dc5992e 100644 --- a/paddle/operators/conv_transpose_op.cc +++ b/paddle/operators/conv_transpose_op.cc @@ -61,9 +61,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { bool use_cudnn = ctx.Attr("use_cudnn"); - if (paddle::platform::is_cpu_place(ctx.GetPlace())) { - use_cudnn = false; - } + use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); framework::LibraryType library_; if (use_cudnn) { library_ = framework::LibraryType::kCUDNN; diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc index 648a1dfb56ac4f85486cdc68e62dbbaed1664297..ebe7d9a0a581f187dcb7e37d2a2da1f26c8d72bb 100644 --- a/paddle/operators/pool_op.cc +++ b/paddle/operators/pool_op.cc @@ -64,9 +64,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { framework::OpKernelType PoolOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { bool use_cudnn = ctx.Attr("use_cudnn"); - if (paddle::platform::is_cpu_place(ctx.GetPlace())) { - use_cudnn = false; - } + use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); framework::LibraryType library_; if (use_cudnn) { library_ = framework::LibraryType::kCUDNN; diff --git a/python/paddle/v2/fluid/nets.py b/python/paddle/v2/fluid/nets.py index 327d2ff2da191cad3983381ad67cb5e82c89b0bb..440467e0abc9b74c2a80f68b44d9abd1a9a58991 100644 --- a/python/paddle/v2/fluid/nets.py +++ b/python/paddle/v2/fluid/nets.py @@ -41,10 +41,9 @@ def img_conv_group(input, param_attr=None, conv_with_batchnorm=False, conv_batchnorm_drop_rate=None, - conv_use_cudnn=True, pool_stride=1, pool_type=None, - pool_use_cudnn=True): + use_cudnn=True): """ Image Convolution Group, Used for vgg net. """ @@ -76,7 +75,7 @@ def img_conv_group(input, padding=conv_padding[i], param_attr=param_attr[i], act=local_conv_act, - use_cudnn=conv_use_cudnn) + use_cudnn=use_cudnn) if conv_with_batchnorm[i]: tmp = layers.batch_norm(input=tmp, act=conv_act) @@ -89,7 +88,7 @@ def img_conv_group(input, pool_size=pool_size, pool_type=pool_type, pool_stride=pool_stride, - use_cudnn=pool_use_cudnn) + use_cudnn=use_cudnn) return pool_out