Unverified · Commit a2b70aa6, authored by Meiyim, committed by GitHub

Fix for pd22 (#763)

* fix-load-pretrained model

* update readme

* fix path
Parent ff89c2a6
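
The pattern behind most of the hunks below: the scripts build checkpoint paths with `pathlib.Path` (e.g. `args.save_dir / 'ckpt.bin'`), but Paddle 2.2's `paddle.save`/`paddle.load` expect plain string paths, so every call site now wraps the path in `str(...)`. A minimal sketch of the before/after behavior, assuming a Paddle 2.2 environment (the `Linear` model and `./ckpt` directory are illustrative stand-ins, not from this repo):

```python
from pathlib import Path

import paddle as P  # the diffs below alias paddle as P

model = P.nn.Linear(4, 2)   # stand-in for the ERNIE models in the diffs
save_dir = Path('./ckpt')   # illustrative; the scripts get this from args.save_dir
save_dir.mkdir(parents=True, exist_ok=True)

# Before the fix: passing a pathlib.Path can fail under Paddle 2.2,
# whose save/load APIs expect a str path.
# P.save(model.state_dict(), save_dir / 'ckpt.bin')   # may raise under pd22

# After the fix: convert explicitly at every call site.
P.save(model.state_dict(), str(save_dir / 'ckpt.bin'))
sd = P.load(str(save_dir / 'ckpt.bin'))
model.set_state_dict(sd)
```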
@@ -173,7 +173,7 @@ data/xnli
 - Finetune with the `dygraph` model:
 ```script
-python3 ./ernie_d/demo/finetune_classifier.py \
+python3 ./demo/finetune_classifier.py \
     --from_pretrained ernie-1.0 \
     --data_dir ./data/xnli
 ```
...
@@ -153,7 +153,7 @@ if not os.path.exists('./teacher_model.bin'):
         if step % 100 == 0:
             f1 = evaluate_teacher(teacher_model, dev_ds)
             print('teacher f1: %.5f' % f1)
-    P.save(teacher_model.state_dict(), './teacher_model.bin')
+    P.save(teacher_model.state_dict(), str('./teacher_model.bin'))
 else:
     state_dict = P.load('./teacher_model.bin')
     teacher_model.set_state_dict(state_dict)
...
@@ -162,7 +162,7 @@ model = ErnieModelForSequenceClassification.from_pretrained(
 if args.init_checkpoint is not None:
     log.info('loading checkpoint from %s' % args.init_checkpoint)
-    sd = P.load(args.init_checkpoint)
+    sd = P.load(str(args.init_checkpoint))
     model.set_state_dict(sd)

 g_clip = P.nn.ClipGradByGlobalNorm(1.0)  # experimental
@@ -238,9 +238,9 @@ with LogWriter(
                     log_writer.add_scalar('eval/acc', acc, step=step)
                     log.debug('acc %.5f' % acc)
                 if args.save_dir is not None:
-                    P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
+                    P.save(model.state_dict(), str(args.save_dir / 'ckpt.bin'))
     if args.save_dir is not None:
-        P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
+        P.save(model.state_dict(), str(args.save_dir / 'ckpt.bin'))
     if args.inference_model_dir is not None:
         class InferenceModel(ErnieModelForSequenceClassification):
...
@@ -128,7 +128,7 @@ model = ErnieModelForSequenceClassification.from_pretrained(
 if args.init_checkpoint is not None:
     log.info('loading checkpoint from %s' % args.init_checkpoint)
-    sd = P.load(args.init_checkpoint)
+    sd = P.load(str(args.init_checkpoint))
     model.set_state_dict(sd)

 model = P.DataParallel(model)
@@ -195,11 +195,11 @@ with P.amp.auto_cast(enable=args.use_amp):
                     # log_writer.add_scalar('eval/acc', acc, step=step)
                     log.debug('acc %.5f' % acc)
                 if args.save_dir is not None:
-                    P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
+                    P.save(model.state_dict(), str(args.save_dir / 'ckpt.bin'))
             # exit
             if step > args.max_steps:
                 break

     if args.save_dir is not None and env.dev_id == 0:
-        P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
+        P.save(model.state_dict(), str(args.save_dir / 'ckpt.bin'))
 log.debug('done')
@@ -145,7 +145,7 @@ def train(model, train_dataset, dev_dataset, dev_examples, dev_features,
                 log.debug('[step %d] eval result: f1 %.5f em %.5f' %
                           (step, f1, em))
             if env.dev_id == 0 and args.save_dir is not None:
-                P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
+                P.save(model.state_dict(), str(args.save_dir / 'ckpt.bin'))
         if step > max_steps:
             break
@@ -244,4 +244,4 @@ if __name__ == "__main__":
               tokenizer, args)
     log.debug('final eval result: f1 %.5f em %.5f' % (f1, em))
     if env.dev_id == 0 and args.save_dir is not None:
-        P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
+        P.save(model.state_dict(), str(args.save_dir / 'ckpt.bin'))
@@ -249,10 +249,10 @@ with LogWriter(
                 log.debug('eval f1: %.5f' % f1)
                 log_writer.add_scalar('eval/f1', f1, step=step)
             if args.save_dir is not None:
-                P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
+                P.save(model.state_dict(), str(args.save_dir / 'ckpt.bin'))
     f1 = evaluate(model, dev_ds)
     log.debug('final eval f1: %.5f' % f1)
     log_writer.add_scalar('eval/f1', f1, step=step)
     if args.save_dir is not None:
-        P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
+        P.save(model.state_dict(), str(args.save_dir / 'ckpt.bin'))
@@ -177,9 +177,9 @@ if not args.eval:
                 log.debug('acc %.5f' % acc)
             if args.save_dir is not None:
                 P.save(model.state_dict(),
-                       args.save_dir / 'ckpt.bin')
+                       str(args.save_dir / 'ckpt.bin'))
     if args.save_dir is not None:
-        P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
+        P.save(model.state_dict(), str(args.save_dir / 'ckpt.bin'))
 else:
     feature_column = propeller.data.FeatureColumns([
         propeller.data.TextColumn(
@@ -189,7 +189,7 @@ else:
             tokenizer=tokenizer.tokenize),
     ])
-    sd = P.load(args.init_checkpoint)
+    sd = P.load(str(args.init_checkpoint))
     model.set_dict(sd)
     model.eval()
...
@@ -394,7 +394,7 @@ if __name__ == '__main__':
                 log.debug(msg)
             if step % 1000 == 0 and env.dev_id == 0:
                 log.debug('saving...')
-                P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
+                P.save(model.state_dict(), str(args.save_dir / 'ckpt.bin'))
             if step > args.max_steps:
                 break
     log.info('done')
@@ -401,7 +401,7 @@ if __name__ == '__main__':
     rev_dict[tokenizer.pad_id] = ''  # replace [PAD]
     rev_dict[tokenizer.unk_id] = ''  # replace [UNK]
-    sd = P.load(args.save_dir)
+    sd = P.load(str(args.save_dir))
     ernie.set_state_dict(sd)

     def map_fn(src_ids):
...
@@ -308,7 +308,7 @@ def seq2seq(model, tokenizer, args):
             log.debug(msg)
         if args.save_dir is not None and step % 1000 == 0 and env.dev_id == 0:
-            P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
+            P.save(model.state_dict(), str(args.save_dir / 'ckpt.bin'))
         if args.predict_output_dir is not None and step > args.skip_eval_steps and step % args.eval_steps == 0:
             assert args.predict_output_dir.exists(), \
@@ -320,7 +320,7 @@ def seq2seq(model, tokenizer, args):
             evaluate(model, dev_ds, step, args)
     if args.save_dir is not None:
-        P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
+        P.save(model.state_dict(), str(args.save_dir / 'ckpt.bin'))

 if __name__ == '__main__':
@@ -414,7 +414,7 @@ if __name__ == '__main__':
     if args.init_checkpoint is not None:
         log.info('loading checkpoint from %s' % args.init_checkpoint)
-        sd = P.load(args.init_checkpoint)
+        sd = P.load(str(args.init_checkpoint))
         ernie.set_state_dict(sd)

     seq2seq(ernie, tokenizer, args)
@@ -290,7 +290,7 @@ class PretrainedModel(object):
         # log.debug('load pretrained weight from program state')
         # F.io.load_program_state(param_path)  # buggy in dygraph.guard, push paddle to fix
         if state_dict_path.exists():
-            m = P.load(state_dict_path)
+            m = P.load(str(state_dict_path))
             for k, v in model.state_dict().items():
                 if k not in m:
                     log.warn('param:%s not set in pretrained model, skip' % k)
...