未验证 提交 73bc3db8 编写于 作者: W WJJ1995 提交者: GitHub

fixed linear fuse (#800)

上级 4eb7510d
......@@ -118,7 +118,9 @@ class TraceFcFuser(FuseBase):
(1, 0))
self.rm_params.add(weight_name)
bias_numpy = parameters[bias_name]
parameters["{}.bias".format(linear_name)] = np.squeeze(bias_numpy)
if len(bias_numpy.shape) == 2:
bias_numpy = np.squeeze(bias_numpy)
parameters["{}.bias".format(linear_name)] = bias_numpy
self.rm_params.add(bias_name)
new_layer = PaddleLayer(
layers_id[0],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册