提交 5eecd55a 编写于 作者: A Amir Alavi 提交者: François Chollet

load_weights will fail if shape mismatch (#10266)

Fix for #10265
上级 d476ecc7
......@@ -987,16 +987,24 @@ def load_weights_from_hdf5_group_by_name(f, layers, skip_mismatch=False,
' element(s).')
# Set values.
for i in range(len(weight_values)):
if skip_mismatch:
if K.int_shape(symbolic_weights[i]) != weight_values[i].shape:
if K.int_shape(symbolic_weights[i]) != weight_values[i].shape:
if skip_mismatch:
warnings.warn('Skipping loading of weights for layer {}'.format(layer.name) +
' due to mismatch in shape' +
' ({} vs {}).'.format(
symbolic_weights[i].shape,
weight_values[i].shape))
continue
weight_value_tuples.append((symbolic_weights[i],
weight_values[i]))
else:
raise ValueError('Layer #' + str(k) +
' (named "' + layer.name +
'"), weight ' +
str(symbolic_weights[i]) +
' has shape {}'.format(K.int_shape(symbolic_weights[i])) +
', but the saved weight has shape ' +
str(weight_values[i].shape) + '.')
else:
weight_value_tuples.append((symbolic_weights[i],
weight_values[i]))
K.batch_set_value(weight_value_tuples)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册