未验证 提交 16c6580d 编写于 作者: K kinghuin 提交者: GitHub

Fix ernie_gen and CRF bug (#5057)

* add ernie_gen

* optimize ernie_gen

* optimize ernie_gen

* optimize ernie_gen code

* fix crf bug

* add ernie_gen __init__.py

* modify nlp version

* fix ernie_gen predict
上级 885432bf
......@@ -129,7 +129,6 @@ def predict():
map(post_process, vocab.to_tokens(predict_ids)))
print("source :%s\ntarget :%s\npredict:%s\n" %
(source_sentence, tgt_sentence, predict_ids))
break
if __name__ == "__main__":
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = '2.0.0a6'
__version__ = '2.0.0a7'
from . import data
from . import datasets
......
......@@ -230,10 +230,10 @@ class LinearChainCrf(nn.Layer):
return self._seq_index[:length]
def _get_batch_seq_index(self, batch_size, length):
if self._batch_seq_index is None or length > self._batch_seq_index.shape[1]:
if self._batch_seq_index is None or length + 2 > self._batch_seq_index.shape[
1]:
self._batch_seq_index = paddle.cumsum(
paddle.ones([batch_size, length + 2], "int64"),
axis=1) - 1
paddle.ones([batch_size, length + 2], "int64"), axis=1) - 1
if self.with_start_stop_tag:
return self._batch_seq_index[:, :length + 2]
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册