提交 3eebc615 编写于 作者: S smallv0221

fix readme and change cross-entropy api

上级 bec2933a
......@@ -5,8 +5,6 @@
## 1. 任务说明
本文主要介绍基于lstm的语言的模型的实现,给定一个输入词序列(中文分词、英文tokenize),计算其ppl(语言模型困惑度,用户表示句子的流利程度),基于循环神经网络语言模型的介绍可以[参阅论文](https://arxiv.org/abs/1409.2329)。相对于传统的方法,基于循环神经网络的方法能够更好的解决稀疏词的问题。
**目前语言模型要求使用PaddlePaddle 2.0及以上版本或适当的develop版本。**
## 2. 效果说明
......@@ -27,6 +25,22 @@
## 1. 开始第一次模型调用
### 安装说明
* PaddlePaddle 安装
本项目依赖于 PaddlePaddle 2.0-rc1 及以上版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
* PaddleNLP 安装
```shell
pip install paddlenlp>=2.0.0b
```
* 环境依赖
Python的版本要求 3.6+
### 数据准备
为了方便开发者进行测试,我们内置了数据下载脚本,默认自动下载PTB数据集。
......
......@@ -77,8 +77,8 @@ class CrossEntropyLossForLm(nn.Layer):
def forward(self, y, label):
label = paddle.unsqueeze(label, axis=2)
loss = paddle.nn.functional.softmax_with_cross_entropy(
logits=y, label=label, soft_label=False)
loss = paddle.nn.functional.cross_entropy(
input=y, label=label, reduction='none')
loss = paddle.squeeze(loss, axis=[2])
loss = paddle.mean(loss, axis=[0])
loss = paddle.sum(loss)
......@@ -89,4 +89,3 @@ class UpdateModel(paddle.callbacks.Callback):
# This callback reset model hidden states and update learning rate before each epoch begins
def on_epoch_begin(self, epoch=None, logs=None):
self.model.network.reset_states()
......@@ -27,7 +27,7 @@ DuReader-robust数据集是单篇章、抽取式阅读理解数据集,具体
* PaddlePaddle 安装
本项目依赖于 PaddlePaddle 2.0 及以上版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
本项目依赖于 PaddlePaddle 2.0-rc1 及以上版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
* PaddleNLP 安装
......
......@@ -54,12 +54,10 @@ class CrossEntropyLossForSQuAD(paddle.nn.Layer):
start_position, end_position = label
start_position = paddle.unsqueeze(start_position, axis=-1)
end_position = paddle.unsqueeze(end_position, axis=-1)
start_loss = paddle.nn.functional.softmax_with_cross_entropy(
logits=start_logits, label=start_position, soft_label=False)
start_loss = paddle.mean(start_loss)
end_loss = paddle.nn.functional.softmax_with_cross_entropy(
logits=end_logits, label=end_position, soft_label=False)
end_loss = paddle.mean(end_loss)
start_loss = paddle.nn.functional.cross_entropy(
input=start_logits, label=start_position)
end_loss = paddle.nn.functional.cross_entropy(
input=end_logits, label=end_position)
loss = (start_loss + end_loss) / 2
return loss
......
......@@ -41,7 +41,7 @@
* PaddlePaddle 安装
本项目依赖于 PaddlePaddle 2.0 及以上版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
本项目依赖于 PaddlePaddle 2.0-rc1 及以上版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
* PaddleNLP 安装
......
......@@ -27,7 +27,7 @@ SQuAD v2.0
* PaddlePaddle 安装
本项目依赖于 PaddlePaddle 2.0 及以上版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
本项目依赖于 PaddlePaddle 2.0-rc1 及以上版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
* PaddleNLP 安装
......@@ -56,7 +56,7 @@ python -u ./run_squad.py \
--batch_size 12 \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--logging_steps 1000 \
--logging_steps 100 \
--save_steps 1000 \
--warmup_proportion 0.1 \
--weight_decay 0.01 \
......
......@@ -51,12 +51,10 @@ class CrossEntropyLossForSQuAD(paddle.nn.Layer):
start_position, end_position = label
start_position = paddle.unsqueeze(start_position, axis=-1)
end_position = paddle.unsqueeze(end_position, axis=-1)
start_loss = paddle.nn.functional.softmax_with_cross_entropy(
logits=start_logits, label=start_position, soft_label=False)
start_loss = paddle.mean(start_loss)
end_loss = paddle.nn.functional.softmax_with_cross_entropy(
logits=end_logits, label=end_position, soft_label=False)
end_loss = paddle.mean(end_loss)
start_loss = paddle.nn.functional.cross_entropy(
input=start_logits, label=start_position)
end_loss = paddle.nn.functional.cross_entropy(
input=end_logits, label=end_position)
loss = (start_loss + end_loss) / 2
return loss
......
......@@ -378,8 +378,8 @@ def compute_f1(a_gold, a_pred, is_whitespace_splited=True):
pred_toks = normalize_answer(a_pred).split()
if not is_whitespace_splited:
gold_toks = gold_toks[0]
pred_toks = pred_toks[0]
gold_toks = gold_toks[0] if gold_toks else ""
pred_toks = pred_toks[0] if pred_toks else ""
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
num_same = sum(common.values())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册