提交 3056bb86 编写于 作者: L lishuang

fix the get_input_mask function

judge the layer name rather than idx
上级 7aae9b81
......@@ -182,7 +182,7 @@ def get_input_mask(module_defs, idx, CBLidx2mask):
if idx == 0:
return np.ones(3)
if idx == 1:
if module_defs[idx - 1]['type'] == 'focus':
return np.ones(12)
if module_defs[idx - 1]['type'] == 'convolutional':
return CBLidx2mask[idx - 1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册