From 9b6c6837f85ca8b232237bb9001710d820e11b95 Mon Sep 17 00:00:00 2001
From: Meiyim
Date: Thu, 18 Jun 2020 16:33:58 +0800
Subject: [PATCH] Patch (#499)

* upgrade to the new lac interface
* remove jieba
* + init_checkpoint for seq_cls demo
* back to jieba
---
 demo/finetune_classifier.py         | 1 -
 demo/finetune_classifier_dygraph.py | 6 ++++++
 demo/pretrain/make_pretrain_data.py | 1 -
 demo/pretrain/pretrain.py           | 1 -
 demo/pretrain/pretrain_dygraph.py   | 1 -
 ernie/tokenizing_ernie.py           | 6 +++---
 requirements.txt                    | 2 +-
 7 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/demo/finetune_classifier.py b/demo/finetune_classifier.py
index 3c18905..7f610e9 100644
--- a/demo/finetune_classifier.py
+++ b/demo/finetune_classifier.py
@@ -25,7 +25,6 @@ from functools import reduce, partial
 import numpy as np
 import multiprocessing
 import tempfile
-import jieba
 import re
 
 import paddle
diff --git a/demo/finetune_classifier_dygraph.py b/demo/finetune_classifier_dygraph.py
index 53764aa..d82eade 100644
--- a/demo/finetune_classifier_dygraph.py
+++ b/demo/finetune_classifier_dygraph.py
@@ -58,6 +58,7 @@ if __name__ == '__main__':
     parser.add_argument('--save_dir', type=str, default=None, help='model output directory')
     parser.add_argument('--max_steps', type=int, default=None, help='max_train_steps, set this to EPOCH * NUM_SAMPLES / BATCH_SIZE')
     parser.add_argument('--wd', type=float, default=0.01, help='weight decay, aka L2 regularizer')
+    parser.add_argument('--init_checkpoint', type=str, default=None, help='checkpoint to warm start from')
 
     args = parser.parse_args()
 
@@ -103,6 +104,11 @@ if __name__ == '__main__':
     with FD.guard(place):
         model = ErnieModelForSequenceClassification.from_pretrained(args.from_pretrained, num_labels=3, name='')
+        if args.init_checkpoint is not None:
+            log.info('loading checkpoint from %s' % args.init_checkpoint)
+            sd, _ = FD.load_dygraph(args.init_checkpoint)
+            model.set_dict(sd)
+
         g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
         if args.use_lr_decay:
             opt = AdamW(learning_rate=LinearDecay(args.lr, int(args.warmup_proportion * args.max_steps), args.max_steps), parameter_list=model.parameters(), weight_decay=args.wd, grad_clip=g_clip)
 
diff --git a/demo/pretrain/make_pretrain_data.py b/demo/pretrain/make_pretrain_data.py
index 96be1d4..6f3d6ed 100644
--- a/demo/pretrain/make_pretrain_data.py
+++ b/demo/pretrain/make_pretrain_data.py
@@ -1,7 +1,6 @@
 import sys
 import argparse
 import struct
-#import jieba
 import random as r
 import re
 import gzip
diff --git a/demo/pretrain/pretrain.py b/demo/pretrain/pretrain.py
index 7d03e2a..d3ba1c9 100644
--- a/demo/pretrain/pretrain.py
+++ b/demo/pretrain/pretrain.py
@@ -44,7 +44,6 @@ from ernie.optimization import optimization
 
 import propeller.paddle as propeller
 from propeller.paddle.data import Dataset
-#import jieba
 from propeller import log
 
 log.setLevel(logging.DEBUG)
diff --git a/demo/pretrain/pretrain_dygraph.py b/demo/pretrain/pretrain_dygraph.py
index f32e37a..3911ad9 100644
--- a/demo/pretrain/pretrain_dygraph.py
+++ b/demo/pretrain/pretrain_dygraph.py
@@ -45,7 +45,6 @@ from ernie.optimization import AdamW, LinearDecay
 
 import propeller.paddle as propeller
 from propeller.paddle.data import Dataset
-#import jieba
 from propeller import log
 
 log.setLevel(logging.DEBUG)
diff --git a/ernie/tokenizing_ernie.py b/ernie/tokenizing_ernie.py
index 4556bb8..25ff1f5 100644
--- a/ernie/tokenizing_ernie.py
+++ b/ernie/tokenizing_ernie.py
@@ -222,14 +222,14 @@ class ErnieTinyTokenizer(ErnieTokenizer):
     def __init__(self, vocab, sp_model_path, **kwargs):
         super(ErnieTinyTokenizer, self).__init__(vocab, **kwargs)
         import sentencepiece as spm
+        import jieba as jb
         self.sp_model = spm.SentencePieceProcessor()
         self.window_size = 5
         self.sp_model.Load(sp_model_path)
-        from LAC import LAC
-        self.lac = LAC()
+        self.jb = jb
 
     def cut(self, sentence):
-        return self.lac.lexer(sentence)
+        return self.jb.cut(sentence)
 
     def tokenize(self, text):
         if len(text) == 0:
diff --git a/requirements.txt b/requirements.txt
index 23fc00f..cacf671 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,4 +3,4 @@ pyzmq==18.0.2
 six==1.11.0
 sklearn==0.0
 sentencepiece==0.1.8
-LAC
+jieba==0.39
-- 
GitLab
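
Note on the --init_checkpoint change: the demo first builds the model from pretrained
weights, then, if a checkpoint is given, overwrites those weights via
load_dygraph/set_dict. Below is a minimal, self-contained sketch of that round trip
under the Paddle 1.x dygraph API (paddle.fluid.dygraph, aliased FD in the demo). The
Linear layer and the ./tmp_ckpt path are placeholders for illustration, not part of
the patch.

    import paddle.fluid.dygraph as FD

    with FD.guard():
        # Placeholder Layer; the demo uses ErnieModelForSequenceClassification.
        model = FD.Linear(4, 3)
        # Save once so the load below has something to read.
        FD.save_dygraph(model.state_dict(), './tmp_ckpt')
        # load_dygraph returns (parameter dict, optimizer dict); the demo,
        # like this sketch, discards the optimizer dict.
        sd, _ = FD.load_dygraph('./tmp_ckpt')
        # Overwrite the current weights with the checkpoint.
        model.set_dict(sd)

The demo itself is then warm-started with something like
`python3 demo/finetune_classifier_dygraph.py ... --init_checkpoint ./saved/model`
(path illustrative).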
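
Note on the tokenizer change: ErnieTinyTokenizer.cut now delegates the coarse word
segmentation performed before the sentencepiece pass to jieba instead of LAC. A small
sketch of what that call does, assuming only jieba==0.39 from the updated
requirements; the sample sentence and its segmentation are illustrative:

    # -*- coding: utf-8 -*-
    import jieba

    # jieba.cut yields words lazily (it returns a generator, not a list),
    # so call sites that need a list must wrap it in list(...).
    words = list(jieba.cut(u'百度是一家高科技公司'))
    print(words)  # e.g. ['百度', '是', '一家', '高科技', '公司']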