提交 c8ed768c 编写于 作者: C chengduoZH

refine pool2d

上级 2398db5e
...@@ -1253,9 +1253,9 @@ def conv2d(input, ...@@ -1253,9 +1253,9 @@ def conv2d(input,
raise ValueError("num_channels must be divisible by groups.") raise ValueError("num_channels must be divisible by groups.")
num_filter_channels = num_channels / groups num_filter_channels = num_channels / groups
filter_size = utils.convert_to_list(filter_size, 2, 'conv2d.filter_size') filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
stride = utils.convert_to_list(stride, 2, 'conv2d.stride') stride = utils.convert_to_list(stride, 2, 'stride')
padding = utils.convert_to_list(padding, 2, 'conv2d.padding') padding = utils.convert_to_list(padding, 2, 'padding')
if not isinstance(use_cudnn, bool): if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False") raise ValueError("use_cudnn should be True or False")
...@@ -1433,8 +1433,8 @@ def sequence_last_step(input): ...@@ -1433,8 +1433,8 @@ def sequence_last_step(input):
def pool2d(input, def pool2d(input,
pool_size, pool_size,
pool_type, pool_type,
pool_stride=None, pool_stride=1,
pool_padding=None, pool_padding=0,
global_pooling=False, global_pooling=False,
use_cudnn=True, use_cudnn=True,
name=None): name=None):
...@@ -1442,20 +1442,15 @@ def pool2d(input, ...@@ -1442,20 +1442,15 @@ def pool2d(input,
This function adds the operator for pooling in 2 dimensions, using the This function adds the operator for pooling in 2 dimensions, using the
pooling configurations mentioned in input parameters. pooling configurations mentioned in input parameters.
""" """
if pool_padding is None:
pool_padding = [0, 0]
if pool_stride is None:
pool_stride = [1, 1]
if pool_type not in ["max", "avg"]: if pool_type not in ["max", "avg"]:
raise ValueError( raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.", "Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",
str(pool_type)) str(pool_type))
if isinstance(pool_size, int):
pool_size = [pool_size, pool_size] pool_size = utils.convert_to_list(pool_size, 2, 'pool_size')
if isinstance(pool_stride, int): pool_padding = utils.convert_to_list(pool_padding, 2, 'pool_padding')
pool_stride = [pool_stride, pool_stride] pool_stride = utils.convert_to_list(pool_stride, 2, 'pool_stride')
if isinstance(pool_padding, int):
pool_padding = [pool_padding, pool_padding]
if not isinstance(use_cudnn, bool): if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False") raise ValueError("use_cudnn should be True or False")
...@@ -1782,9 +1777,9 @@ def conv2d_transpose(input, ...@@ -1782,9 +1777,9 @@ def conv2d_transpose(input,
raise TypeError("Input of conv2d_transpose must be Variable") raise TypeError("Input of conv2d_transpose must be Variable")
input_channel = input.shape[1] input_channel = input.shape[1]
padding = utils.convert_to_list(padding, 2, 'conv2d_transpose.padding') padding = utils.convert_to_list(padding, 2, 'padding')
stride = utils.convert_to_list(stride, 2, 'conv2d_transpose.stride') stride = utils.convert_to_list(stride, 2, 'stride')
dilation = utils.convert_to_list(dilation, 2, 'conv2d_transpose.dilation') dilation = utils.convert_to_list(dilation, 2, 'dilation')
if not isinstance(use_cudnn, bool): if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False") raise ValueError("use_cudnn should be True or False")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册