提交 24f528a1 编写于 作者: C chengduoZH

follow comments

上级 251c6032
......@@ -70,9 +70,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
if (paddle::platform::is_cpu_place(ctx.GetPlace())) {
use_cudnn = false;
}
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
......
......@@ -61,9 +61,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
if (paddle::platform::is_cpu_place(ctx.GetPlace())) {
use_cudnn = false;
}
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
......
......@@ -64,9 +64,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
framework::OpKernelType PoolOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
if (paddle::platform::is_cpu_place(ctx.GetPlace())) {
use_cudnn = false;
}
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
......
......@@ -41,10 +41,9 @@ def img_conv_group(input,
param_attr=None,
conv_with_batchnorm=False,
conv_batchnorm_drop_rate=None,
conv_use_cudnn=True,
pool_stride=1,
pool_type=None,
pool_use_cudnn=True):
use_cudnn=True):
"""
Image Convolution Group, Used for vgg net.
"""
......@@ -76,7 +75,7 @@ def img_conv_group(input,
padding=conv_padding[i],
param_attr=param_attr[i],
act=local_conv_act,
use_cudnn=conv_use_cudnn)
use_cudnn=use_cudnn)
if conv_with_batchnorm[i]:
tmp = layers.batch_norm(input=tmp, act=conv_act)
......@@ -89,7 +88,7 @@ def img_conv_group(input,
pool_size=pool_size,
pool_type=pool_type,
pool_stride=pool_stride,
use_cudnn=pool_use_cudnn)
use_cudnn=use_cudnn)
return pool_out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册