diff --git a/PaddleNLP/examples/language_model/bert/README.md b/PaddleNLP/examples/language_model/bert/README.md
index b0ca74dd89da343cc521c4f601acae753b330c30..895552a65908ff9ab700c422950ff20319407804 100644
--- a/PaddleNLP/examples/language_model/bert/README.md
+++ b/PaddleNLP/examples/language_model/bert/README.md
@@ -138,3 +138,38 @@ python -u ./run_glue.py \
 | QQP  | Accuracy/F1                | 0.90581/0.87347 |
 | MNLI | Matched acc/MisMatched acc | 0.84422/0.84825 |
 | RTE  | Accuracy                   | 0.711191        |
+
+
+### Prediction
+
+After fine-tuning is complete, the model to be used for prediction can be exported as follows:
+
+```shell
+python -u ./export_model.py \
+    --model_type bert \
+    --model_path bert-base-uncased \
+    --output_path ./infer_model/model
+```
+
+The parameters are described as follows:
+- `model_type`: the model type; set it to bert when using a BERT model.
+- `model_path`: the directory where the fine-tuned model was saved, i.e. the `output_dir` used during training.
+- `output_path`: the file prefix of the exported inference model. The suffixes (`pdiparams`, `pdiparams.info`, `pdmodel`) are appended on save; in addition, the tokenizer files are saved to the directory containing `output_path`.
+
+Prediction on a GLUE evaluation task can then be run as follows (based on Paddle's [Python inference API](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/guides/05_inference_deployment/inference/python_infer_cn.html)):
+
+```shell
+python -u ./predict_glue.py \
+    --task_name SST-2 \
+    --model_type bert \
+    --model_path ./infer_model/model \
+    --batch_size 32 \
+    --max_seq_length 128
+```
+
+The parameters are described as follows:
+- `task_name`: the GLUE task to run prediction on.
+- `model_type`: the model type; set it to bert when using a BERT model.
+- `model_path`: the file prefix of the inference model, i.e. the `output_path` used when exporting the model in the previous step.
+- `batch_size`: the number of samples in each prediction batch.
+- `max_seq_length`: the maximum sequence length; longer sequences are truncated.
diff --git a/PaddleNLP/examples/language_model/bert/export_model.py b/PaddleNLP/examples/language_model/bert/export_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fff7bbdbea4d159f4a4f74882abc64fcd452bb4
--- /dev/null
+++ b/PaddleNLP/examples/language_model/bert/export_model.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
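+"""Export a fine-tuned BERT GLUE model as a static-graph inference model."""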
+
+import argparse
+import os
+
+import paddle
+
+from run_glue import MODEL_CLASSES
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    # Required parameters
+    parser.add_argument(
+        "--model_type",
+        default=None,
+        type=str,
+        required=True,
+        help="Model type selected in the list: " +
+        ", ".join(MODEL_CLASSES.keys()),
+    )
+    parser.add_argument(
+        "--model_path",
+        default=None,
+        type=str,
+        required=True,
+        help="Path of the trained model to be exported.",
+    )
+    parser.add_argument(
+        "--output_path",
+        default=None,
+        type=str,
+        required=True,
+        help="The output file prefix used to save the exported inference model.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+
+    args.model_type = args.model_type.lower()
+    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+
+    # Build the model and load the fine-tuned parameters.
+    model = model_class.from_pretrained(args.model_path)
+    # Switch to eval mode.
+    model.eval()
+    # Convert to static graph with specific input description.
+    model = paddle.jit.to_static(
+        model,
+        input_spec=[
+            paddle.static.InputSpec(
+                shape=[None, None], dtype="int64"),  # input_ids
+            paddle.static.InputSpec(
+                shape=[None, None], dtype="int64")  # segment_ids
+        ])
+    # Save the converted static graph model.
+    paddle.jit.save(model, args.output_path)
+    # Also save the tokenizer so inference can load it from the same directory.
+    tokenizer = tokenizer_class.from_pretrained(args.model_path)
+    tokenizer.save_pretrained(os.path.dirname(args.output_path))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/PaddleNLP/examples/language_model/bert/predict_glue.py b/PaddleNLP/examples/language_model/bert/predict_glue.py
new file mode 100644
index 0000000000000000000000000000000000000000..946ca8d54a397052d730483df07aedb6e7322f40
--- /dev/null
+++ b/PaddleNLP/examples/language_model/bert/predict_glue.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
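+"""Run GLUE prediction with the inference model saved by export_model.py."""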
+
+import argparse
+import os
+from functools import partial
+
+import paddle
+from paddlenlp.data import Stack, Tuple, Pad
+
+from run_glue import convert_example, TASK_CLASSES, MODEL_CLASSES
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    # Required parameters
+    parser.add_argument(
+        "--task_name",
+        default=None,
+        type=str,
+        required=True,
+        help="The name of the task to perform prediction, selected in the list: "
+        + ", ".join(TASK_CLASSES.keys()),
+    )
+    parser.add_argument(
+        "--model_type",
+        default=None,
+        type=str,
+        required=True,
+        help="Model type selected in the list: " +
+        ", ".join(MODEL_CLASSES.keys()),
+    )
+    parser.add_argument(
+        "--model_path",
+        default=None,
+        type=str,
+        required=True,
+        help="The path prefix of the inference model to be used.",
+    )
+    parser.add_argument(
+        "--select_device",
+        default="gpu",
+        choices=["gpu", "cpu", "xpu"],
+        help="Device selected for inference.",
+    )
+    parser.add_argument(
+        "--batch_size",
+        default=32,
+        type=int,
+        help="Batch size for prediction.",
+    )
+    parser.add_argument(
+        "--max_seq_length",
+        default=128,
+        type=int,
+        help="The maximum total input sequence length after tokenization. Sequences longer "
+        "than this will be truncated, sequences shorter will be padded.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+class Predictor(object):
+    def __init__(self, predictor, input_handles, output_handles):
+        self.predictor = predictor
+        self.input_handles = input_handles
+        self.output_handles = output_handles
+
+    @classmethod
+    def create_predictor(cls, args):
+        config = paddle.inference.Config(args.model_path + ".pdmodel",
+                                         args.model_path + ".pdiparams")
+        if args.select_device == "gpu":
+            # Set GPU configs accordingly.
+            config.enable_use_gpu(100, 0)
+        elif args.select_device == "cpu":
+            # Set CPU configs accordingly,
+            # such as enable_mkldnn, set_cpu_math_library_num_threads.
+            config.disable_gpu()
+        elif args.select_device == "xpu":
+            # Set XPU configs accordingly.
+            config.enable_xpu(100)
+        config.switch_use_feed_fetch_ops(False)
+        predictor = paddle.inference.create_predictor(config)
+        input_handles = [
+            predictor.get_input_handle(name)
+            for name in predictor.get_input_names()
+        ]
+        output_handles = [
+            predictor.get_output_handle(name)
+            for name in predictor.get_output_names()
+        ]
+        return cls(predictor, input_handles, output_handles)
+
+    def predict_batch(self, data):
+        for input_field, input_handle in zip(data, self.input_handles):
+            input_handle.copy_from_cpu(input_field.numpy()
+                                       if isinstance(input_field, paddle.Tensor)
+                                       else input_field)
+        self.predictor.run()
+        output = [
+            output_handle.copy_to_cpu() for output_handle in self.output_handles
+        ]
+        return output
+
+    def predict(self, dataset, collate_fn, batch_size=1):
+        batch_sampler = paddle.io.BatchSampler(dataset,
+                                               batch_size=batch_size,
+                                               shuffle=False)
+        data_loader = paddle.io.DataLoader(dataset=dataset,
+                                           batch_sampler=batch_sampler,
+                                           collate_fn=collate_fn,
+                                           num_workers=0,
+                                           return_list=True)
+        outputs = []
+        for data in data_loader:
+            output = self.predict_batch(data)
+            outputs.append(output)
+        return outputs
+
+
+def main():
+    args = parse_args()
+
+    predictor = Predictor.create_predictor(args)
+
+    args.task_name = args.task_name.lower()
+    dataset_class, metric_class = TASK_CLASSES[args.task_name]
+    args.model_type = args.model_type.lower()
+    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+
+    dataset = dataset_class.get_datasets("test")
+    tokenizer = tokenizer_class.from_pretrained(
+        os.path.dirname(args.model_path))
+    transform_fn = partial(convert_example,
+                           tokenizer=tokenizer,
+                           label_list=dataset.get_labels(),
+                           max_seq_length=args.max_seq_length,
+                           is_test=True)
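+    # Pad input_ids and segment_ids per batch and stack the sequence lengths;
+    # the length field (index 2) is not an input of the exported model, so it
+    # is dropped before feeding the predictor.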
+    batchify_fn = lambda samples, fn=Tuple(
+        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input_ids
+        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # segment_ids
+        Stack(),  # length
+    ): [data for i, data in enumerate(fn(samples)) if i != 2]
+    dataset = dataset.apply(transform_fn)
+
+    predictor.predict(dataset,
+                      batch_size=args.batch_size,
+                      collate_fn=batchify_fn)
+
+
+if __name__ == "__main__":
+    main()
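As a quick sanity check of the export step, the static-graph files written by `paddle.jit.save` can also be loaded back directly, without going through the inference API used in predict_glue.py. The sketch below is illustrative only and not part of this patch; it assumes `export_model.py` was run with `--output_path ./infer_model/model`, and the token ids are placeholders rather than real tokenizer output:

```python
import paddle

# Load the static-graph model saved by paddle.jit.save in export_model.py.
model = paddle.jit.load("./infer_model/model")
model.eval()

# Placeholder token ids; real inputs should come from the saved tokenizer.
input_ids = paddle.to_tensor(
    [[101, 2023, 3185, 2003, 2307, 102]], dtype="int64")
segment_ids = paddle.zeros_like(input_ids)  # single-sentence input

# Positional inputs match the InputSpec order used during export.
logits = model(input_ids, segment_ids)
print(paddle.argmax(logits, axis=-1).numpy())
```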