diff --git a/tools/export_model.py b/tools/export_model.py index 86e84eaaa7c294722208e096fd3b14ab6b20f1ae..dd134b3cb4ee8c2c32534dcff9dace072325babf 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -29,7 +29,7 @@ from ppcls.arch import build_model from ppcls.utils.save_load import load_dygraph_pretrain -class ClasModel(nn.Layer): +class ExportModel(nn.Layer): """ ClasModel: add softmax onto the model """ @@ -37,7 +37,11 @@ class ClasModel(nn.Layer): def __init__(self, config): super().__init__() self.base_model = build_model(config) - self.softmax = nn.Softmax(axis=-1) + self.infer_output_key = config.get("infer_output_key") + if config.get("infer_add_softmax", False): + self.softmax = nn.Softmax(axis=-1) + else: + self.softmax = None def eval(self): self.training = False @@ -47,7 +51,10 @@ class ClasModel(nn.Layer): def forward(self, x): x = self.base_model(x) - x = self.softmax(x) + if self.infer_output_key is not None: + x = x[self.infer_output_key] + if self.softmax is not None: + x = self.softmax(x) return x @@ -57,8 +64,7 @@ if __name__ == "__main__": # set device assert config["Global"]["device"] in ["cpu", "gpu", "xpu"] device = paddle.set_device(config["Global"]["device"]) - - model = ClasModel(config["Arch"]) + model = ExportModel(config["Arch"]) if config["Global"]["pretrained_model"] is not None: load_dygraph_pretrain(model.base_model,