提交 33032b12 编写于 作者: X xzl

fix bug: regenrate test proto of img_conv

上级 300b5094
...@@ -874,7 +874,7 @@ class Conv(Cfg): ...@@ -874,7 +874,7 @@ class Conv(Cfg):
filter_size_y=None, filter_size_y=None,
padding_y=None, padding_y=None,
stride_y=None, stride_y=None,
dilation=1, dilation=None,
dilation_y=None): dilation_y=None):
self.add_keys(locals()) self.add_keys(locals())
if filter_size_y is None: if filter_size_y is None:
...@@ -1388,6 +1388,10 @@ def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False): ...@@ -1388,6 +1388,10 @@ def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False):
conv_conf.stride_y = conv.stride_y conv_conf.stride_y = conv.stride_y
conv_conf.groups = conv.groups conv_conf.groups = conv.groups
conv_conf.caffe_mode = conv.caffe_mode conv_conf.caffe_mode = conv.caffe_mode
if not conv.dilation:
conv.dilation = 1
conv.dilation_y = 1
else:
conv_conf.dilation = conv.dilation conv_conf.dilation = conv.dilation
conv_conf.dilation_y = conv.dilation_y conv_conf.dilation_y = conv.dilation_y
...@@ -1397,20 +1401,20 @@ def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False): ...@@ -1397,20 +1401,20 @@ def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False):
get_img_size(input_layer_name, conv.channels) get_img_size(input_layer_name, conv.channels)
conv_conf.output_x = cnn_output_size( conv_conf.output_x = cnn_output_size(
conv_conf.img_size, conv_conf.filter_size, conv_conf.padding, conv_conf.img_size, conv_conf.filter_size, conv_conf.padding,
conv_conf.stride, conv_conf.caffe_mode, conv_conf.dilation) conv_conf.stride, conv_conf.caffe_mode, conv.dilation)
conv_conf.output_y = cnn_output_size( conv_conf.output_y = cnn_output_size(
conv_conf.img_size_y, conv_conf.filter_size_y, conv_conf.padding_y, conv_conf.img_size_y, conv_conf.filter_size_y, conv_conf.padding_y,
conv_conf.stride_y, conv_conf.caffe_mode, conv_conf.dilation_y) conv_conf.stride_y, conv_conf.caffe_mode, conv.dilation_y)
else: else:
conv_conf.filter_channels = num_filters / conv.groups conv_conf.filter_channels = num_filters / conv.groups
conv_conf.output_x, conv_conf.output_y = \ conv_conf.output_x, conv_conf.output_y = \
get_img_size(input_layer_name, conv.channels) get_img_size(input_layer_name, conv.channels)
conv_conf.img_size = cnn_image_size( conv_conf.img_size = cnn_image_size(
conv_conf.output_x, conv_conf.filter_size, conv_conf.padding, conv_conf.output_x, conv_conf.filter_size, conv_conf.padding,
conv_conf.stride, conv_conf.caffe_mode, conv_conf.dilation) conv_conf.stride, conv_conf.caffe_mode, conv.dilation)
conv_conf.img_size_y = cnn_image_size( conv_conf.img_size_y = cnn_image_size(
conv_conf.output_y, conv_conf.filter_size_y, conv_conf.padding_y, conv_conf.output_y, conv_conf.filter_size_y, conv_conf.padding_y,
conv_conf.stride_y, conv_conf.caffe_mode, conv_conf.dilation_y) conv_conf.stride_y, conv_conf.caffe_mode, conv.dilation_y)
#caffe_mode: compute the output size using floor instead of ceil, #caffe_mode: compute the output size using floor instead of ceil,
......
...@@ -2523,7 +2523,9 @@ def img_conv_layer(input, ...@@ -2523,7 +2523,9 @@ def img_conv_layer(input,
if layer_type: if layer_type:
if dilation > 1 or dilation_y > 1: if dilation > 1 or dilation_y > 1:
assert layer_type in ["cudnn_conv", "cudnn_convt"] assert layer_type in [
"cudnn_conv", "cudnn_convt", "exconv", "exconvt"
]
if trans: if trans:
assert layer_type in ["exconvt", "cudnn_convt"] assert layer_type in ["exconvt", "cudnn_convt"]
else: else:
......
...@@ -28,6 +28,8 @@ layers { ...@@ -28,6 +28,8 @@ layers {
stride_y: 1 stride_y: 1
output_y: 227 output_y: 227
img_size_y: 256 img_size_y: 256
dilation: 1
dilation_y: 1
} }
} }
bias_parameter_name: "___conv_0__.wbias" bias_parameter_name: "___conv_0__.wbias"
......
...@@ -28,6 +28,8 @@ layers { ...@@ -28,6 +28,8 @@ layers {
stride_y: 1 stride_y: 1
output_y: 227 output_y: 227
img_size_y: 256 img_size_y: 256
dilation: 1
dilation_y: 1
} }
} }
bias_parameter_name: "___conv_0__.wbias" bias_parameter_name: "___conv_0__.wbias"
......
...@@ -28,6 +28,8 @@ layers { ...@@ -28,6 +28,8 @@ layers {
stride_y: 1 stride_y: 1
output_y: 48 output_y: 48
img_size_y: 48 img_size_y: 48
dilation: 1
dilation_y: 1
} }
} }
bias_parameter_name: "___conv_0__.wbias" bias_parameter_name: "___conv_0__.wbias"
......
...@@ -30,6 +30,8 @@ layers { ...@@ -30,6 +30,8 @@ layers {
stride_y: 1 stride_y: 1
output_y: 48 output_y: 48
img_size_y: 48 img_size_y: 48
dilation: 1
dilation_y: 1
} }
} }
bias_parameter_name: "___conv_0__.wbias" bias_parameter_name: "___conv_0__.wbias"
...@@ -105,6 +107,8 @@ layers { ...@@ -105,6 +107,8 @@ layers {
stride_y: 1 stride_y: 1
output_y: 24 output_y: 24
img_size_y: 24 img_size_y: 24
dilation: 1
dilation_y: 1
} }
} }
bias_parameter_name: "___conv_1__.wbias" bias_parameter_name: "___conv_1__.wbias"
......
...@@ -30,6 +30,8 @@ layers { ...@@ -30,6 +30,8 @@ layers {
stride_y: 1 stride_y: 1
output_y: 48 output_y: 48
img_size_y: 48 img_size_y: 48
dilation: 1
dilation_y: 1
} }
} }
bias_parameter_name: "___conv_0__.wbias" bias_parameter_name: "___conv_0__.wbias"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册