提交 78dc9343 编写于 作者: C chengduoZH

expose use_cudnn

上级 5ad1aef0
......@@ -660,6 +660,7 @@ def conv2d(input,
groups=None,
param_attr=None,
bias_attr=None,
use_cudnn=False,
act=None):
"""
**Convlution2D Layer**
......@@ -758,6 +759,8 @@ def conv2d(input,
stride = [stride, stride]
if isinstance(padding, int):
padding = [padding, padding]
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
input_shape = input.shape
filter_shape = [num_filters, num_filter_channels] + filter_size
......@@ -781,9 +784,12 @@ def conv2d(input,
'Filter': filter_param,
},
outputs={"Output": pre_bias},
attrs={'strides': stride,
attrs={
'strides': stride,
'paddings': padding,
'groups': groups})
'groups': groups,
'use_cudnn': use_cudnn
})
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
......@@ -931,7 +937,8 @@ def pool2d(input,
pool_type,
pool_stride=None,
pool_padding=None,
global_pooling=False):
global_pooling=False,
use_cudnn=False):
"""
This function adds the operator for pooling in 2 dimensions, using the
pooling configurations mentioned in input parameters.
......@@ -950,6 +957,8 @@ def pool2d(input,
pool_stride = [pool_stride, pool_stride]
if isinstance(pool_padding, int):
pool_padding = [pool_padding, pool_padding]
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
helper = LayerHelper('pool2d', **locals())
dtype = helper.input_dtype()
......@@ -964,7 +973,8 @@ def pool2d(input,
"ksize": pool_size,
"global_pooling": global_pooling,
"strides": pool_stride,
"paddings": pool_padding
"paddings": pool_padding,
"use_cudnn": use_cudnn
})
return pool_out
......@@ -1077,7 +1087,8 @@ def conv2d_transpose(input,
padding=None,
stride=None,
dilation=None,
param_attr=None):
param_attr=None,
use_cudnn=False):
"""
The transpose of conv2d layer.
......@@ -1132,6 +1143,10 @@ def conv2d_transpose(input,
elif dilation is not None:
op_attr['dilations'] = dilation
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
op_attr['use_cudnn'] = use_cudnn
if filter_size is None:
if output_size is None:
raise ValueError("output_size must be set when filter_size is None")
......
......@@ -13,19 +13,22 @@ def simple_img_conv_pool(input,
pool_stride,
act,
param_attr=None,
pool_type='max'):
pool_type='max',
use_cudnn=False):
conv_out = layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
param_attr=param_attr,
act=act)
act=act,
use_cudnn=use_cudnn)
pool_out = layers.pool2d(
input=conv_out,
pool_size=pool_size,
pool_type=pool_type,
pool_stride=pool_stride)
pool_stride=pool_stride,
use_cudnn=use_cudnn)
return pool_out
......@@ -38,8 +41,10 @@ def img_conv_group(input,
param_attr=None,
conv_with_batchnorm=False,
conv_batchnorm_drop_rate=None,
conv_use_cudnn=False,
pool_stride=1,
pool_type=None):
pool_type=None,
pool_use_cudnn=False):
"""
Image Convolution Group, Used for vgg net.
"""
......@@ -70,7 +75,8 @@ def img_conv_group(input,
filter_size=conv_filter_size[i],
padding=conv_padding[i],
param_attr=param_attr[i],
act=local_conv_act)
act=local_conv_act,
use_cudnn=conv_use_cudnn)
if conv_with_batchnorm[i]:
tmp = layers.batch_norm(input=tmp, act=conv_act)
......@@ -82,7 +88,8 @@ def img_conv_group(input,
input=tmp,
pool_size=pool_size,
pool_type=pool_type,
pool_stride=pool_stride)
pool_stride=pool_stride,
use_cudnn=pool_use_cudnn)
return pool_out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册