提交 24f528a1 编写于 作者: C chengduoZH

follow comments

上级 251c6032
...@@ -70,9 +70,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -70,9 +70,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOp::GetExpectedKernelType( framework::OpKernelType ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn"); bool use_cudnn = ctx.Attr<bool>("use_cudnn");
if (paddle::platform::is_cpu_place(ctx.GetPlace())) { use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
use_cudnn = false;
}
framework::LibraryType library_; framework::LibraryType library_;
if (use_cudnn) { if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
......
...@@ -61,9 +61,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -61,9 +61,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn"); bool use_cudnn = ctx.Attr<bool>("use_cudnn");
if (paddle::platform::is_cpu_place(ctx.GetPlace())) { use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
use_cudnn = false;
}
framework::LibraryType library_; framework::LibraryType library_;
if (use_cudnn) { if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
......
...@@ -64,9 +64,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -64,9 +64,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
framework::OpKernelType PoolOp::GetExpectedKernelType( framework::OpKernelType PoolOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const { const framework::ExecutionContext &ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn"); bool use_cudnn = ctx.Attr<bool>("use_cudnn");
if (paddle::platform::is_cpu_place(ctx.GetPlace())) { use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
use_cudnn = false;
}
framework::LibraryType library_; framework::LibraryType library_;
if (use_cudnn) { if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
......
...@@ -41,10 +41,9 @@ def img_conv_group(input, ...@@ -41,10 +41,9 @@ def img_conv_group(input,
param_attr=None, param_attr=None,
conv_with_batchnorm=False, conv_with_batchnorm=False,
conv_batchnorm_drop_rate=None, conv_batchnorm_drop_rate=None,
conv_use_cudnn=True,
pool_stride=1, pool_stride=1,
pool_type=None, pool_type=None,
pool_use_cudnn=True): use_cudnn=True):
""" """
Image Convolution Group, Used for vgg net. Image Convolution Group, Used for vgg net.
""" """
...@@ -76,7 +75,7 @@ def img_conv_group(input, ...@@ -76,7 +75,7 @@ def img_conv_group(input,
padding=conv_padding[i], padding=conv_padding[i],
param_attr=param_attr[i], param_attr=param_attr[i],
act=local_conv_act, act=local_conv_act,
use_cudnn=conv_use_cudnn) use_cudnn=use_cudnn)
if conv_with_batchnorm[i]: if conv_with_batchnorm[i]:
tmp = layers.batch_norm(input=tmp, act=conv_act) tmp = layers.batch_norm(input=tmp, act=conv_act)
...@@ -89,7 +88,7 @@ def img_conv_group(input, ...@@ -89,7 +88,7 @@ def img_conv_group(input,
pool_size=pool_size, pool_size=pool_size,
pool_type=pool_type, pool_type=pool_type,
pool_stride=pool_stride, pool_stride=pool_stride,
use_cudnn=pool_use_cudnn) use_cudnn=use_cudnn)
return pool_out return pool_out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册