未验证 提交 16c6580d 编写于 作者: K kinghuin 提交者: GitHub

Fix ernie_gen and CRF bug (#5057)

* add ernie_gen

* optimize ernie_gen

* optimize ernie_gen

* optimize ernie_gen code

* fix crf bug

* add ernie_gen __init__.py

* modify nlp version

* fix ernie_gen predict
上级 885432bf
...@@ -129,7 +129,6 @@ def predict(): ...@@ -129,7 +129,6 @@ def predict():
map(post_process, vocab.to_tokens(predict_ids))) map(post_process, vocab.to_tokens(predict_ids)))
print("source :%s\ntarget :%s\npredict:%s\n" % print("source :%s\ntarget :%s\npredict:%s\n" %
(source_sentence, tgt_sentence, predict_ids)) (source_sentence, tgt_sentence, predict_ids))
break
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
__version__ = '2.0.0a6' __version__ = '2.0.0a7'
from . import data from . import data
from . import datasets from . import datasets
......
...@@ -230,10 +230,10 @@ class LinearChainCrf(nn.Layer): ...@@ -230,10 +230,10 @@ class LinearChainCrf(nn.Layer):
return self._seq_index[:length] return self._seq_index[:length]
def _get_batch_seq_index(self, batch_size, length): def _get_batch_seq_index(self, batch_size, length):
if self._batch_seq_index is None or length > self._batch_seq_index.shape[1]: if self._batch_seq_index is None or length + 2 > self._batch_seq_index.shape[
1]:
self._batch_seq_index = paddle.cumsum( self._batch_seq_index = paddle.cumsum(
paddle.ones([batch_size, length + 2], "int64"), paddle.ones([batch_size, length + 2], "int64"), axis=1) - 1
axis=1) - 1
if self.with_start_stop_tag: if self.with_start_stop_tag:
return self._batch_seq_index[:, :length + 2] return self._batch_seq_index[:, :length + 2]
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册