Unverified commit 9b6c6837, authored by Meiyim, committed by GitHub

Patch (#499)

* upgrade to the new lac interface

* remove jieba

* + init_checkpoint for seq_cls demo

* back to jieba
Parent 77d5d106
@@ -25,7 +25,6 @@ from functools import reduce, partial
 import numpy as np
 import multiprocessing
 import tempfile
-import jieba
 import re
 import paddle
...
@@ -58,6 +58,7 @@ if __name__ == '__main__':
     parser.add_argument('--save_dir', type=str, default=None, help='model output directory')
     parser.add_argument('--max_steps', type=int, default=None, help='max_train_steps, set this to EPOCH * NUM_SAMPLES / BATCH_SIZE')
     parser.add_argument('--wd', type=float, default=0.01, help='weight decay, aka L2 regularizer')
+    parser.add_argument('--init_checkpoint', type=str, default=None, help='checkpoint to warm start from')
     args = parser.parse_args()
@@ -103,6 +104,11 @@ if __name__ == '__main__':
     with FD.guard(place):
         model = ErnieModelForSequenceClassification.from_pretrained(args.from_pretrained, num_labels=3, name='')
+        if args.init_checkpoint is not None:
+            log.info('loading checkpoint from %s' % args.init_checkpoint)
+            sd, _ = FD.load_dygraph(args.init_checkpoint)
+            model.set_dict(sd)
         g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
         if args.use_lr_decay:
             opt = AdamW(learning_rate=LinearDecay(args.lr, int(args.warmup_proportion * args.max_steps), args.max_steps), parameter_list=model.parameters(), weight_decay=args.wd, grad_clip=g_clip)
...
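Taken together, the two hunks above add a --init_checkpoint flag and use it to warm-start the classifier before fine-tuning. Below is a minimal sketch of that pattern in isolation, assuming Paddle 1.x dygraph APIs; the pretrained-model name and checkpoint path are placeholders, not taken from this patch.

# Warm-start sketch mirroring the code added above; assumes Paddle 1.x
# dygraph mode. 'ernie-1.0' and the checkpoint path are placeholders.
import paddle.fluid as F
import paddle.fluid.dygraph as FD
from ernie.modeling_ernie import ErnieModelForSequenceClassification

place = F.CUDAPlace(0)  # or F.CPUPlace() on a machine without a GPU
with FD.guard(place):
    model = ErnieModelForSequenceClassification.from_pretrained('ernie-1.0', num_labels=3, name='')
    # load_dygraph returns (parameter dict, optimizer dict); only the
    # parameters are needed to warm start a new fine-tuning run
    sd, _ = FD.load_dygraph('./checkpoints/seq_cls')  # placeholder path
    model.set_dict(sd)

Note that load_dygraph expects the path prefix of a checkpoint written by save_dygraph, without the .pdparams suffix.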
 import sys
 import argparse
 import struct
-#import jieba
 import random as r
 import re
 import gzip
...
@@ -44,7 +44,6 @@ from ernie.optimization import optimization
 import propeller.paddle as propeller
 from propeller.paddle.data import Dataset
-#import jieba
 from propeller import log
 log.setLevel(logging.DEBUG)
...
@@ -45,7 +45,6 @@ from ernie.optimization import AdamW, LinearDecay
 import propeller.paddle as propeller
 from propeller.paddle.data import Dataset
-#import jieba
 from propeller import log
 log.setLevel(logging.DEBUG)
...
@@ -222,14 +222,14 @@ class ErnieTinyTokenizer(ErnieTokenizer):
     def __init__(self, vocab, sp_model_path, **kwargs):
         super(ErnieTinyTokenizer, self).__init__(vocab, **kwargs)
         import sentencepiece as spm
+        import jieba as jb
         self.sp_model = spm.SentencePieceProcessor()
         self.window_size = 5
         self.sp_model.Load(sp_model_path)
-        from LAC import LAC
-        self.lac = LAC()
+        self.jb = jb

     def cut(self, sentence):
-        return self.lac.lexer(sentence)
+        return self.jb.cut(sentence)

     def tokenize(self, text):
         if len(text) == 0:
...
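The hunk above reverts ErnieTinyTokenizer's coarse segmentation from LAC back to jieba ("back to jieba" in the commit message): input is first cut into word-level segments, which the tokenizer then feeds to its SentencePiece model. A rough illustration of the surviving jieba path follows; the LAC lines only mirror the code this patch removes, so treat that interface as version-dependent.

# jieba.cut returns a generator of word-level segments; these are what
# ErnieTinyTokenizer passes on to SentencePiece downstream.
import jieba

words = list(jieba.cut('百度提出了ERNIE模型'))
print(words)

# Removed alternative, as it appeared in the reverted code:
# from LAC import LAC
# lac = LAC()
# words = lac.lexer('百度提出了ERNIE模型')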
@@ -3,4 +3,4 @@ pyzmq==18.0.2
 six==1.11.0
 sklearn==0.0
 sentencepiece==0.1.8
-LAC
+jieba==0.39