提交 0d096855 编写于 作者: M Megvii Engine Team

fix(imperative/quantization): add conditional when passing parameters name

GitOrigin-RevId: f38042aeabd653cbc23a744de581b550135c0a86
上级 652ec9f2
...@@ -40,7 +40,9 @@ class _ConvBnActivation2d(Conv2d): ...@@ -40,7 +40,9 @@ class _ConvBnActivation2d(Conv2d):
) )
weight = w_fold.astype(qat_module.get_weight_dtype()) weight = w_fold.astype(qat_module.get_weight_dtype())
qconv.weight = Parameter(weight.numpy(), name=qat_module.conv.weight.name) qconv.weight = Parameter(weight.numpy(), name=qat_module.conv.weight.name)
qconv.bias = Parameter(b_fold.numpy(), name=qat_module.conv.bias.name) qconv.bias = Parameter(b_fold.numpy())
if qat_module.conv.bias is not None:
qconv.bias.name = qat_module.conv.bias.name
return qconv return qconv
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册