dygraph.load_persistables的参数数量和model.state_dict()数量不一致
Created by: Exception-star
dygraph.load_persistables的参数数量 与 model.state_dict()的数量不一致

问题描述：按理说两者的参数数量应该一致，但打印 model.state_dict() 后发现其中没有 conv 的参数，只有 bn 的参数。conv 和 BN 定义在同一个类（ConvBN）里面。
class ConvBN(fluid.dygraph.Layer):
    """Conv2D followed by BatchNorm, packaged as a single dygraph layer.

    Args:
        name_scope: name scope passed to the parent ``Layer`` and to both
            sublayers.
        num_filters: number of output channels of the convolution; also the
            number of channels normalized by the BatchNorm.
        filter_size: convolution kernel size.
        stride: convolution stride.
        dilation: convolution dilation. When it is not 1, padding is set to
            ``dilation``; otherwise padding is ``(filter_size - 1) // 2``.
        act: activation applied by the BatchNorm layer (the conv itself is
            built with ``act=None``).
        dtype: dtype forwarded to both sublayers.
        bias_attr: ``False`` disables the conv bias; any other value creates
            a named conv bias parameter.
    """

    def __init__(self,
                 name_scope,
                 num_filters,
                 filter_size=3,
                 stride=1,
                 dilation=1,
                 act=None,
                 dtype='float32',
                 bias_attr=False):
        super(ConvBN, self).__init__(name_scope)
        if dilation != 1:
            padding = dilation
        else:
            padding = (filter_size - 1) // 2

        # Parameter names must be unique. Giving both the conv and the BN
        # (and every ConvBN instance) the same fixed names such as 'weight'
        # and 'bias' makes their entries collide, so the conv parameters
        # disappear from state_dict() and saved persistables. Prefixing with
        # this layer's unique full name keeps the readable suffixes while
        # avoiding collisions.
        prefix = self.full_name()

        self._conv = fluid.dygraph.Conv2D(
            name_scope,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            act=None,
            dtype=dtype,
            bias_attr=(bias_attr if bias_attr is False else
                       fluid.ParamAttr(name=prefix + '.conv.bias')),
            param_attr=fluid.ParamAttr(name=prefix + '.conv.weight'))
        self._bn = fluid.dygraph.BatchNorm(
            name_scope,
            num_channels=num_filters,
            act=act,
            dtype=dtype,
            momentum=0.9,
            epsilon=1e-5,
            bias_attr=fluid.ParamAttr(name=prefix + '.bn.bias'),
            param_attr=fluid.ParamAttr(name=prefix + '.bn.weight'),
            moving_mean_name=prefix + '.bn.running_mean',
            moving_variance_name=prefix + '.bn.running_var')

    def forward(self, inputs):
        """Apply the convolution, then the batch norm (with its activation)."""
        x = self._conv(inputs)
        x = self._bn(x)
        return x