提交 93e0609f 编写于 作者: C chengduoZH

fix bug

上级 d4dabe3e
...@@ -1262,7 +1262,7 @@ def conv2d(input, ...@@ -1262,7 +1262,7 @@ def conv2d(input,
raise ValueError("use_cudnn should be True or False") raise ValueError("use_cudnn should be True or False")
input_shape = input.shape input_shape = input.shape
filter_shape = [num_filters, num_filter_channels] + filter_size filter_shape = [num_filters, num_filter_channels] + list(filter_size)
def _get_default_param_initializer(): def _get_default_param_initializer():
std = (2.0 / (filter_size[0]**2 * num_channels))**0.5 std = (2.0 / (filter_size[0]**2 * num_channels))**0.5
...@@ -1826,7 +1826,7 @@ def conv2d_transpose(input, ...@@ -1826,7 +1826,7 @@ def conv2d_transpose(input,
elif isinstance(filter_size, int): elif isinstance(filter_size, int):
filter_size = [filter_size, filter_size] filter_size = [filter_size, filter_size]
filter_shape = [input_channel, num_filters] + filter_size filter_shape = [input_channel, num_filters] + list(filter_size)
img_filter = helper.create_parameter( img_filter = helper.create_parameter(
dtype=input.dtype, shape=filter_shape, attr=helper.param_attr) dtype=input.dtype, shape=filter_shape, attr=helper.param_attr)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册