未验证 提交 f6d9ea33 编写于 作者: Rooooooooc's avatar Rooooooooc 提交者: GitHub

Merge pull request #16 from becauseofAI/patch-3

fix the pad bug
......@@ -185,7 +185,7 @@ def cfg2prototxt(cfgfile):
prev_filters = block['filters']
convolution_param['kernel_size'] = block['size']
if block['pad'] == '1':
convolution_param['pad'] = str(int(convolution_param['kernel_size'])/2)
convolution_param['pad'] = str(int(convolution_param['kernel_size']) // 2)
convolution_param['stride'] = block['stride']
if block['batch_normalize'] == '1':
convolution_param['bias_term'] = 'false'
......@@ -254,7 +254,7 @@ def cfg2prototxt(cfgfile):
convolution_param['num_output'] = prev_filters
convolution_param['kernel_size'] = block['size']
if block['pad'] == '1':
convolution_param['pad'] = str(int(convolution_param['kernel_size'])/2)
convolution_param['pad'] = str(int(convolution_param['kernel_size']) // 2)
convolution_param['stride'] = block['stride']
if block['batch_normalize'] == '1':
convolution_param['bias_term'] = 'false'
......@@ -320,15 +320,15 @@ def cfg2prototxt(cfgfile):
pooling_param = OrderedDict()
pooling_param['stride'] = block['stride']
pooling_param['pool'] = 'MAX'
pooling_param['kernel_size'] = block['size']
pooling_param['pad'] = str((int(block['size'])-1)/2)
#if (int(block['size']) - int(block['stride'])) % 2 == 0:
# pooling_param['kernel_size'] = block['size']
# pooling_param['pad'] = str((int(block['size'])-1)/2)
#if (int(block['size']) - int(block['stride'])) % 2 == 1:
# pooling_param['kernel_size'] = str(int(block['size']) + 1)
# pooling_param['pad'] = str((int(block['size']) + 1)/2)
# pooling_param['kernel_size'] = block['size']
# pooling_param['pad'] = str((int(block['size'])-1) // 2)
if (int(block['size']) - int(block['stride'])) % 2 == 0:
pooling_param['kernel_size'] = block['size']
pooling_param['pad'] = str((int(block['size'])-1) // 2)
if (int(block['size']) - int(block['stride'])) % 2 == 1:
pooling_param['kernel_size'] = str(int(block['size']) + 1)
pooling_param['pad'] = str((int(block['size']) + 1) // 2)
max_layer['pooling_param'] = pooling_param
layers.append(max_layer)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册