提交 c8ed768c 编写于 作者: C chengduoZH

refine pool2d

上级 2398db5e
......@@ -1253,9 +1253,9 @@ def conv2d(input,
raise ValueError("num_channels must be divisible by groups.")
num_filter_channels = num_channels / groups
filter_size = utils.convert_to_list(filter_size, 2, 'conv2d.filter_size')
stride = utils.convert_to_list(stride, 2, 'conv2d.stride')
padding = utils.convert_to_list(padding, 2, 'conv2d.padding')
filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
stride = utils.convert_to_list(stride, 2, 'stride')
padding = utils.convert_to_list(padding, 2, 'padding')
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
......@@ -1433,8 +1433,8 @@ def sequence_last_step(input):
def pool2d(input,
pool_size,
pool_type,
pool_stride=None,
pool_padding=None,
pool_stride=1,
pool_padding=0,
global_pooling=False,
use_cudnn=True,
name=None):
......@@ -1442,20 +1442,15 @@ def pool2d(input,
This function adds the operator for pooling in 2 dimensions, using the
pooling configurations mentioned in input parameters.
"""
if pool_padding is None:
pool_padding = [0, 0]
if pool_stride is None:
pool_stride = [1, 1]
if pool_type not in ["max", "avg"]:
raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",
str(pool_type))
if isinstance(pool_size, int):
pool_size = [pool_size, pool_size]
if isinstance(pool_stride, int):
pool_stride = [pool_stride, pool_stride]
if isinstance(pool_padding, int):
pool_padding = [pool_padding, pool_padding]
pool_size = utils.convert_to_list(pool_size, 2, 'pool_size')
pool_padding = utils.convert_to_list(pool_padding, 2, 'pool_padding')
pool_stride = utils.convert_to_list(pool_stride, 2, 'pool_stride')
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
......@@ -1782,9 +1777,9 @@ def conv2d_transpose(input,
raise TypeError("Input of conv2d_transpose must be Variable")
input_channel = input.shape[1]
padding = utils.convert_to_list(padding, 2, 'conv2d_transpose.padding')
stride = utils.convert_to_list(stride, 2, 'conv2d_transpose.stride')
dilation = utils.convert_to_list(dilation, 2, 'conv2d_transpose.dilation')
padding = utils.convert_to_list(padding, 2, 'padding')
stride = utils.convert_to_list(stride, 2, 'stride')
dilation = utils.convert_to_list(dilation, 2, 'dilation')
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册