未验证 提交 85105600 编写于 作者: S Steffy-zxf 提交者: GitHub

update docs (#5180)

* update codes for senta benchmark
上级 cb760565
......@@ -129,7 +129,7 @@ def predict(model, data, tokenizer, label_map, batch_size=1):
examples = []
for text in data:
input_ids, segment_ids = convert_example(
[text],
text,
tokenizer,
label_list=label_map.values(),
max_seq_length=args.max_seq_length,
......
......@@ -153,13 +153,13 @@ wget https://paddlenlp.bj.bcebos.com/data/senta_word_dict.txt
CPU 启动:
```shell
python train.py --vocab_path='./senta_word_dict.txt' --use_gpu=False --network=bilstm --lr=5e-4 --batch_size=64 --epochs=5 --save_dir='./checkpoints'
python train.py --vocab_path='./senta_word_dict.txt' --use_gpu=False --network=bilstm --lr=5e-4 --batch_size=64 --epochs=10 --save_dir='./checkpoints'
```
GPU 启动:
```shell
# CUDA_VISIBLE_DEVICES=0 python train.py --vocab_path='./senta_word_dict.txt' --use_gpu=True --network=bilstm --lr=5e-4 --batch_size=64 --epochs=5 --save_dir='./checkpoints'
CUDA_VISIBLE_DEVICES=0 python train.py --vocab_path='./senta_word_dict.txt' --use_gpu=True --network=bilstm --lr=5e-4 --batch_size=64 --epochs=10 --save_dir='./checkpoints'
```
以上参数表示:
......
......@@ -160,10 +160,12 @@ if __name__ == "__main__":
print("Loaded checkpoint from %s" % args.init_from_ckpt)
# Starts training and evaluating.
callback = paddle.callbacks.ProgBarLogger(log_freq=10, verbose=3)
model.fit(train_loader,
dev_loader,
epochs=args.epochs,
save_dir=args.save_dir)
save_dir=args.save_dir,
callbacks=callback)
# Finally tests model.
results = model.evaluate(test_loader)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册