未验证 提交 9f44af9d 编写于 作者: Y Yu Yang 提交者: GitHub

Fix #6460 (#6461)

上级 95924686
......@@ -762,7 +762,7 @@ def sequence_conv(input,
helper = LayerHelper('sequence_conv', **locals())
dtype = helper.input_dtype()
filter_shape = [filter_size * input.shape[1], num_filters]
filter = helper.create_parameter(
filter_param = helper.create_parameter(
attr=helper.param_attr, shape=filter_shape, dtype=dtype)
pre_bias = helper.create_tmp_variable(dtype)
......@@ -770,7 +770,7 @@ def sequence_conv(input,
type='sequence_conv',
inputs={
'X': [input],
'Filter': [filter],
'Filter': [filter_param],
},
outputs={"Out": pre_bias},
attrs={
......@@ -785,7 +785,7 @@ def sequence_conv(input,
def conv2d(input,
num_filters,
filter_size,
stride=[1, 1],
stride=None,
padding=None,
groups=None,
param_attr=None,
......@@ -802,6 +802,8 @@ def conv2d(input,
conv-2d output, if mentioned in the input parameters.
"""
if stride is None:
stride = [1, 1]
helper = LayerHelper('conv2d', **locals())
dtype = helper.input_dtype()
......@@ -827,7 +829,7 @@ def conv2d(input,
std = (2.0 / (filter_size[0]**2 * num_channels))**0.5
return Normal(0.0, std, 0)
filter = helper.create_parameter(
filter_param = helper.create_parameter(
attr=helper.param_attr,
shape=filter_shape,
dtype=dtype,
......@@ -839,7 +841,7 @@ def conv2d(input,
type='conv2d_cudnn',
inputs={
'Input': input,
'Filter': filter,
'Filter': filter_param,
},
outputs={"Output": pre_bias},
attrs={'strides': stride,
......@@ -875,8 +877,8 @@ def sequence_pool(input, pool_type, **kwargs):
def pool2d(input,
pool_size,
pool_type,
pool_stride=[1, 1],
pool_padding=[0, 0],
pool_stride=None,
pool_padding=None,
global_pooling=False,
main_program=None,
startup_program=None):
......@@ -884,6 +886,10 @@ def pool2d(input,
This function adds the operator for pooling in 2 dimensions, using the
pooling configurations mentioned in input parameters.
"""
if pool_padding is None:
pool_padding = [0, 0]
if pool_stride is None:
pool_stride = [1, 1]
if pool_type not in ["max", "avg"]:
raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册