# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export a trained SemanticIndexBaseStatic checkpoint as a Paddle
static-graph inference model (written via ``paddle.jit.save``)."""

import argparse
import os
from functools import partial

import numpy as np
import paddle
import paddle.nn.functional as F
import paddlenlp as ppnlp
from paddlenlp.data import Stack, Tuple, Pad

from base_model import SemanticIndexBase, SemanticIndexBaseStatic

# yapf: disable
parser = argparse.ArgumentParser()
# NOTE: the flag is required, so it has no default (a default would never be used).
parser.add_argument("--params_path", type=str, required=True, help="The path to model parameters to be loaded, e.g. ./checkpoint/model_900/model_state.pdparams.")
parser.add_argument("--output_path", type=str, default='./output', help="The path of model parameter in static graph to be saved.")
# Presumably must match the value used when the checkpoint was trained -- confirm against the training script.
parser.add_argument("--output_emb_size", type=int, default=256, help="Output embedding size of the semantic index model.")
# yapf: enable

if __name__ == "__main__":
    # Parse CLI arguments only when run as a script, so importing this
    # module does not consume sys.argv.
    args = parser.parse_args()

    # Fail fast: without this check a missing/mistyped path would silently
    # export a model that does not contain the fine-tuned parameters.
    if not os.path.isfile(args.params_path):
        parser.error("--params_path '%s' is not an existing file." %
                     args.params_path)

    # Build the ERNIE 1.0 backbone and wrap it in the static-graph-friendly
    # semantic index model.
    pretrained_model = ppnlp.transformers.ErnieModel.from_pretrained(
        "ernie-1.0")
    model = SemanticIndexBaseStatic(
        pretrained_model, output_emb_size=args.output_emb_size)

    # Overwrite the pretrained weights with the fine-tuned checkpoint.
    state_dict = paddle.load(args.params_path)
    model.set_dict(state_dict)
    print("Loaded parameters from %s" % args.params_path)
    model.eval()

    # Convert to static graph with specific input description:
    # both inputs are int64 with dynamic (batch, sequence) dimensions.
    model = paddle.jit.to_static(
        model,
        input_spec=[
            paddle.static.InputSpec(
                shape=[None, None], dtype="int64"),  # input_ids
            paddle.static.InputSpec(
                shape=[None, None], dtype="int64")  # segment_ids
        ])

    # Save in static graph model; files are written with the
    # "<output_path>/inference" prefix.
    save_path = os.path.join(args.output_path, "inference")
    paddle.jit.save(model, save_path)