You need to sign in or sign up before continuing.
提交 7573205c 编写于 作者: L Luo Tao

follow comments on config_parser

上级 96615fe3
...@@ -78,10 +78,10 @@ message ConvConfig { ...@@ -78,10 +78,10 @@ message ConvConfig {
required uint32 stride_y = 12; required uint32 stride_y = 12;
// if not set, use output_x // if not set, use output_x
optional uint32 output_y = 13 [default = 0]; optional uint32 output_y = 13;
// if not set, use img_size // if not set, use img_size
optional uint32 img_size_y = 14 [default = 0]; optional uint32 img_size_y = 14;
} }
message PoolConfig { message PoolConfig {
...@@ -161,10 +161,10 @@ message NormConfig { ...@@ -161,10 +161,10 @@ message NormConfig {
optional bool blocked = 8; optional bool blocked = 8;
// if not set, use output_x // if not set, use output_x
optional uint32 output_y = 9 [default = 0]; optional uint32 output_y = 9;
// if not set, use img_size // if not set, use img_size
optional uint32 img_size_y = 10 [default = 0]; optional uint32 img_size_y = 10;
} }
message BlockExpandConfig { message BlockExpandConfig {
......
...@@ -1066,7 +1066,7 @@ def cnn_output_size(img_size, filter_size, padding, stride, caffe_mode): ...@@ -1066,7 +1066,7 @@ def cnn_output_size(img_size, filter_size, padding, stride, caffe_mode):
return 1 + int(math.ceil(output)) return 1 + int(math.ceil(output))
#calcualte image_size based on output_size for convolution. #calcualte image_size based on output_size for de-convolution (ConvTransLayer).
#It is the reverse function of cnn_output_size #It is the reverse function of cnn_output_size
def cnn_image_size(output_size, filter_size, padding, stride, caffe_mode): def cnn_image_size(output_size, filter_size, padding, stride, caffe_mode):
img_size = (output_size - 1) * stride + filter_size - 2 * padding img_size = (output_size - 1) * stride + filter_size - 2 * padding
...@@ -1075,7 +1075,7 @@ def cnn_image_size(output_size, filter_size, padding, stride, caffe_mode): ...@@ -1075,7 +1075,7 @@ def cnn_image_size(output_size, filter_size, padding, stride, caffe_mode):
return img_size return img_size
def set_img_size(input_layer_name, channels): def get_img_size(input_layer_name, channels):
input = g_layer_map[input_layer_name] input = g_layer_map[input_layer_name]
img_pixels = input.size / channels img_pixels = input.size / channels
img_size = input.width if input.width > 0 else int(img_pixels**0.5) img_size = input.width if input.width > 0 else int(img_pixels**0.5)
...@@ -1110,7 +1110,7 @@ def parse_pool(pool, input_layer_name, pool_conf): ...@@ -1110,7 +1110,7 @@ def parse_pool(pool, input_layer_name, pool_conf):
pool_conf.stride_y = default(pool.stride_y, pool_conf.stride) pool_conf.stride_y = default(pool.stride_y, pool_conf.stride)
pool_conf.img_size, pool_conf.img_size_y = \ pool_conf.img_size, pool_conf.img_size_y = \
set_img_size(input_layer_name, pool.channels) get_img_size(input_layer_name, pool.channels)
config_assert(not pool.start, "start is deprecated in pooling.") config_assert(not pool.start, "start is deprecated in pooling.")
...@@ -1137,7 +1137,7 @@ def parse_spp(spp, input_layer_name, spp_conf): ...@@ -1137,7 +1137,7 @@ def parse_spp(spp, input_layer_name, spp_conf):
def parse_image(image, input_layer_name, image_conf): def parse_image(image, input_layer_name, image_conf):
image_conf.channels = image.channels image_conf.channels = image.channels
image_conf.img_size, image_conf.img_size_y = \ image_conf.img_size, image_conf.img_size_y = \
set_img_size(input_layer_name, image_conf.channels) get_img_size(input_layer_name, image_conf.channels)
def parse_norm(norm, input_layer_name, norm_conf): def parse_norm(norm, input_layer_name, norm_conf):
...@@ -1152,7 +1152,7 @@ def parse_norm(norm, input_layer_name, norm_conf): ...@@ -1152,7 +1152,7 @@ def parse_norm(norm, input_layer_name, norm_conf):
norm_conf.blocked = norm.blocked norm_conf.blocked = norm.blocked
norm_conf.img_size, norm_conf.img_size_y = \ norm_conf.img_size, norm_conf.img_size_y = \
set_img_size(input_layer_name, norm.channels) get_img_size(input_layer_name, norm.channels)
norm_conf.output_x = norm_conf.img_size norm_conf.output_x = norm_conf.img_size
norm_conf.output_y = norm_conf.img_size_y norm_conf.output_y = norm_conf.img_size_y
if norm.norm_type in ['cmrnorm-projection']: if norm.norm_type in ['cmrnorm-projection']:
...@@ -1177,7 +1177,7 @@ def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False): ...@@ -1177,7 +1177,7 @@ def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False):
if not trans: if not trans:
conv_conf.filter_channels = conv.channels / conv.groups conv_conf.filter_channels = conv.channels / conv.groups
conv_conf.img_size, conv_conf.img_size_y = \ conv_conf.img_size, conv_conf.img_size_y = \
set_img_size(input_layer_name, conv.channels) get_img_size(input_layer_name, conv.channels)
conv_conf.output_x = cnn_output_size( conv_conf.output_x = cnn_output_size(
conv_conf.img_size, conv_conf.filter_size, conv_conf.padding, conv_conf.img_size, conv_conf.filter_size, conv_conf.padding,
conv_conf.stride, conv_conf.caffe_mode) conv_conf.stride, conv_conf.caffe_mode)
...@@ -1187,11 +1187,11 @@ def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False): ...@@ -1187,11 +1187,11 @@ def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False):
else: else:
conv_conf.filter_channels = num_filters / conv.groups conv_conf.filter_channels = num_filters / conv.groups
conv_conf.output_x, conv_conf.output_y = \ conv_conf.output_x, conv_conf.output_y = \
set_img_size(input_layer_name, conv.channels) get_img_size(input_layer_name, conv.channels)
conv_conf.img_size = cnn_image_size( conv_conf.img_size = cnn_image_size(
conv_conf.output_x, conv_conf.filter_size, conv_conf.padding, conv_conf.output_x, conv_conf.filter_size, conv_conf.padding,
conv_conf.stride, conv_conf.caffe_mode) conv_conf.stride, conv_conf.caffe_mode)
conv_conf.img_size_y = cnn_output_size( conv_conf.img_size_y = cnn_image_size(
conv_conf.output_y, conv_conf.filter_size_y, conv_conf.padding_y, conv_conf.output_y, conv_conf.filter_size_y, conv_conf.padding_y,
conv_conf.stride_y, conv_conf.caffe_mode) conv_conf.stride_y, conv_conf.caffe_mode)
......
...@@ -27,7 +27,7 @@ layers { ...@@ -27,7 +27,7 @@ layers {
padding_y: 1 padding_y: 1
stride_y: 1 stride_y: 1
output_y: 227 output_y: 227
img_size_y: 198 img_size_y: 256
} }
} }
bias_parameter_name: "___conv_0__.wbias" bias_parameter_name: "___conv_0__.wbias"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册