未验证 提交 6abaa0c8 编写于 作者: L littletomatodonkey 提交者: GitHub

fix dpn (#416)

上级 463980f3
...@@ -209,7 +209,7 @@ class DualPathFactory(nn.Layer): ...@@ -209,7 +209,7 @@ class DualPathFactory(nn.Layer):
class DPN(nn.Layer): class DPN(nn.Layer):
def __init__(self, layers=60, class_dim=1000): def __init__(self, layers=68, class_dim=1000):
super(DPN, self).__init__() super(DPN, self).__init__()
self._class_dim = class_dim self._class_dim = class_dim
...@@ -230,7 +230,7 @@ class DPN(nn.Layer): ...@@ -230,7 +230,7 @@ class DPN(nn.Layer):
self.conv1_x_1_func = ConvBNLayer( self.conv1_x_1_func = ConvBNLayer(
num_channels=3, num_channels=3,
num_filters=init_num_filter, num_filters=init_num_filter,
filter_size=3, filter_size=init_filter_size,
stride=2, stride=2,
pad=1, pad=1,
act='relu', act='relu',
......
...@@ -57,7 +57,7 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False): ...@@ -57,7 +57,7 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False):
for key in model_dict.keys(): for key in model_dict.keys():
weight_name = model_dict[key].name weight_name = model_dict[key].name
if weight_name in pre_state_dict.keys(): if weight_name in pre_state_dict.keys():
print('Load weight: {}, shape: {}'.format( logger.info('Load weight: {}, shape: {}'.format(
weight_name, pre_state_dict[weight_name].shape)) weight_name, pre_state_dict[weight_name].shape))
param_state_dict[key] = pre_state_dict[weight_name] param_state_dict[key] = pre_state_dict[weight_name]
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册