提交 952726a4 编写于 作者: W wjj19950828

fixed fc fuser

上级 081e773b
...@@ -113,10 +113,11 @@ class TraceFcFuser(FuseBase): ...@@ -113,10 +113,11 @@ class TraceFcFuser(FuseBase):
attrs["out_features"] = parameters[weight_name].shape[0] attrs["out_features"] = parameters[weight_name].shape[0]
linear_name = "linear{}".format(self.linear_index) linear_name = "linear{}".format(self.linear_index)
self.linear_index += 1 self.linear_index += 1
parameters["{}.weight".format(linear_name)] = parameters[ weight_numpy = parameters.pop(weight_name)
weight_name].transpose((1, 0)) parameters["{}.weight".format(linear_name)] = weight_numpy.transpose(
parameters["{}.bias".format(linear_name)] = np.squeeze(parameters[ (1, 0))
bias_name]) bias_numpy = parameters.pop(bias_name)
parameters["{}.bias".format(linear_name)] = np.squeeze(bias_numpy)
new_layer = PaddleLayer( new_layer = PaddleLayer(
layers_id[0], layers_id[0],
"paddle.nn.Linear", "paddle.nn.Linear",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册