未验证 提交 9f44af9d 编写于 作者: Y Yu Yang 提交者: GitHub

Fix #6460 (#6461)

上级 95924686
...@@ -762,7 +762,7 @@ def sequence_conv(input, ...@@ -762,7 +762,7 @@ def sequence_conv(input,
helper = LayerHelper('sequence_conv', **locals()) helper = LayerHelper('sequence_conv', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
filter_shape = [filter_size * input.shape[1], num_filters] filter_shape = [filter_size * input.shape[1], num_filters]
filter = helper.create_parameter( filter_param = helper.create_parameter(
attr=helper.param_attr, shape=filter_shape, dtype=dtype) attr=helper.param_attr, shape=filter_shape, dtype=dtype)
pre_bias = helper.create_tmp_variable(dtype) pre_bias = helper.create_tmp_variable(dtype)
...@@ -770,7 +770,7 @@ def sequence_conv(input, ...@@ -770,7 +770,7 @@ def sequence_conv(input,
type='sequence_conv', type='sequence_conv',
inputs={ inputs={
'X': [input], 'X': [input],
'Filter': [filter], 'Filter': [filter_param],
}, },
outputs={"Out": pre_bias}, outputs={"Out": pre_bias},
attrs={ attrs={
...@@ -785,7 +785,7 @@ def sequence_conv(input, ...@@ -785,7 +785,7 @@ def sequence_conv(input,
def conv2d(input, def conv2d(input,
num_filters, num_filters,
filter_size, filter_size,
stride=[1, 1], stride=None,
padding=None, padding=None,
groups=None, groups=None,
param_attr=None, param_attr=None,
...@@ -802,6 +802,8 @@ def conv2d(input, ...@@ -802,6 +802,8 @@ def conv2d(input,
conv-2d output, if mentioned in the input parameters. conv-2d output, if mentioned in the input parameters.
""" """
if stride is None:
stride = [1, 1]
helper = LayerHelper('conv2d', **locals()) helper = LayerHelper('conv2d', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
...@@ -827,7 +829,7 @@ def conv2d(input, ...@@ -827,7 +829,7 @@ def conv2d(input,
std = (2.0 / (filter_size[0]**2 * num_channels))**0.5 std = (2.0 / (filter_size[0]**2 * num_channels))**0.5
return Normal(0.0, std, 0) return Normal(0.0, std, 0)
filter = helper.create_parameter( filter_param = helper.create_parameter(
attr=helper.param_attr, attr=helper.param_attr,
shape=filter_shape, shape=filter_shape,
dtype=dtype, dtype=dtype,
...@@ -839,7 +841,7 @@ def conv2d(input, ...@@ -839,7 +841,7 @@ def conv2d(input,
type='conv2d_cudnn', type='conv2d_cudnn',
inputs={ inputs={
'Input': input, 'Input': input,
'Filter': filter, 'Filter': filter_param,
}, },
outputs={"Output": pre_bias}, outputs={"Output": pre_bias},
attrs={'strides': stride, attrs={'strides': stride,
...@@ -875,8 +877,8 @@ def sequence_pool(input, pool_type, **kwargs): ...@@ -875,8 +877,8 @@ def sequence_pool(input, pool_type, **kwargs):
def pool2d(input, def pool2d(input,
pool_size, pool_size,
pool_type, pool_type,
pool_stride=[1, 1], pool_stride=None,
pool_padding=[0, 0], pool_padding=None,
global_pooling=False, global_pooling=False,
main_program=None, main_program=None,
startup_program=None): startup_program=None):
...@@ -884,6 +886,10 @@ def pool2d(input, ...@@ -884,6 +886,10 @@ def pool2d(input,
This function adds the operator for pooling in 2 dimensions, using the This function adds the operator for pooling in 2 dimensions, using the
pooling configurations mentioned in input parameters. pooling configurations mentioned in input parameters.
""" """
if pool_padding is None:
pool_padding = [0, 0]
if pool_stride is None:
pool_stride = [1, 1]
if pool_type not in ["max", "avg"]: if pool_type not in ["max", "avg"]:
raise ValueError( raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.", "Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册