提交 03d32d61 编写于 作者: W weishengyu

dbg

上级 29219340
......@@ -41,7 +41,7 @@ class ExportModel(nn.Layer):
self.infer_output_key = config.get("infer_output_key")
if self.infer_output_key == "features" and isinstance(self.base_model,
RecModel):
self.base_model.neck = Identity()
self.base_model.head = IdentityHead()
if config.get("infer_add_softmax", True):
self.softmax = nn.Softmax(axis=-1)
else:
......@@ -62,6 +62,14 @@ class ExportModel(nn.Layer):
return x
class IdentityHead(nn.Layer):
def __init__(self):
super(IdentityHead, self).__init__()
def forward(self, x, label):
return {"features": x, "logits": None}
if __name__ == "__main__":
args = config.parse_args()
config = config.get_config(args.config, overrides=args.override, show=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册