未验证 提交 a2b70aa6 编写于 作者: M Meiyim 提交者: GitHub

Fix for pd22 (#763)

* fix loading of pretrained model

* update readme

* fix path
上级 ff89c2a6
......@@ -173,7 +173,7 @@ data/xnli
- 使用 `动态图` 模型进行finetune:
```script
python3 ./ernie_d/demo/finetune_classifier.py \
python3 ./demo/finetune_classifier.py \
--from_pretrained ernie-1.0 \
--data_dir ./data/xnli
```
......
......@@ -153,7 +153,7 @@ if not os.path.exists('./teacher_model.bin'):
if step % 100 == 0:
f1 = evaluate_teacher(teacher_model, dev_ds)
print('teacher f1: %.5f' % f1)
P.save(teacher_model.state_dict(), './teacher_model.bin')
P.save(teacher_model.state_dict(),str( './teacher_model.bin'))
else:
state_dict = P.load('./teacher_model.bin')
teacher_model.set_state_dict(state_dict)
......
......@@ -162,7 +162,7 @@ model = ErnieModelForSequenceClassification.from_pretrained(
if args.init_checkpoint is not None:
log.info('loading checkpoint from %s' % args.init_checkpoint)
sd = P.load(args.init_checkpoint)
sd = P.load(str(args.init_checkpoint))
model.set_state_dict(sd)
g_clip = P.nn.ClipGradByGlobalNorm(1.0) #experimental
......@@ -238,9 +238,9 @@ with LogWriter(
log_writer.add_scalar('eval/acc', acc, step=step)
log.debug('acc %.5f' % acc)
if args.save_dir is not None:
P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
P.save(model.state_dict(), str(args.save_dir / 'ckpt.bin'))
if args.save_dir is not None:
P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
P.save(model.state_dict(),str( args.save_dir / 'ckpt.bin'))
if args.inference_model_dir is not None:
class InferenceModel(ErnieModelForSequenceClassification):
......
......@@ -128,7 +128,7 @@ model = ErnieModelForSequenceClassification.from_pretrained(
if args.init_checkpoint is not None:
log.info('loading checkpoint from %s' % args.init_checkpoint)
sd = P.load(args.init_checkpoint)
sd = P.load(str(args.init_checkpoint))
model.set_state_dict(sd)
model = P.DataParallel(model)
......@@ -195,11 +195,11 @@ with P.amp.auto_cast(enable=args.use_amp):
#log_writer.add_scalar('eval/acc', acc, step=step)
log.debug('acc %.5f' % acc)
if args.save_dir is not None:
P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
P.save(model.state_dict(),str( args.save_dir / 'ckpt.bin'))
# exit
if step > args.max_steps:
break
if args.save_dir is not None and env.dev_id == 0:
P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
P.save(model.state_dict(),str( args.save_dir / 'ckpt.bin'))
log.debug('done')
......@@ -145,7 +145,7 @@ def train(model, train_dataset, dev_dataset, dev_examples, dev_features,
log.debug('[step %d] eval result: f1 %.5f em %.5f' %
(step, f1, em))
if env.dev_id == 0 and args.save_dir is not None:
P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
P.save(model.state_dict(), str(args.save_dir / 'ckpt.bin'))
if step > max_steps:
break
......@@ -244,4 +244,4 @@ if __name__ == "__main__":
tokenizer, args)
log.debug('final eval result: f1 %.5f em %.5f' % (f1, em))
if env.dev_id == 0 and args.save_dir is not None:
P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
P.save(model.state_dict(), str(args.save_dir / 'ckpt.bin'))
......@@ -249,10 +249,10 @@ with LogWriter(
log.debug('eval f1: %.5f' % f1)
log_writer.add_scalar('eval/f1', f1, step=step)
if args.save_dir is not None:
P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
P.save(model.state_dict(),str( args.save_dir / 'ckpt.bin'))
f1 = evaluate(model, dev_ds)
log.debug('final eval f1: %.5f' % f1)
log_writer.add_scalar('eval/f1', f1, step=step)
if args.save_dir is not None:
P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
P.save(model.state_dict(),str( args.save_dir / 'ckpt.bin'))
......@@ -177,9 +177,9 @@ if not args.eval:
log.debug('acc %.5f' % acc)
if args.save_dir is not None:
P.save(model.state_dict(),
args.save_dir / 'ckpt.bin')
str(args.save_dir / 'ckpt.bin'))
if args.save_dir is not None:
P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
P.save(model.state_dict(), str(args.save_dir / 'ckpt.bin'))
else:
feature_column = propeller.data.FeatureColumns([
propeller.data.TextColumn(
......@@ -189,7 +189,7 @@ else:
tokenizer=tokenizer.tokenize),
])
sd = P.load(args.init_checkpoint)
sd = P.load(str(args.init_checkpoint))
model.set_dict(sd)
model.eval()
......
......@@ -394,7 +394,7 @@ if __name__ == '__main__':
log.debug(msg)
if step % 1000 == 0 and env.dev_id == 0:
log.debug('saveing...')
P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
P.save(model.state_dict(),str( args.save_dir / 'ckpt.bin'))
if step > args.max_steps:
break
log.info('done')
......@@ -401,7 +401,7 @@ if __name__ == '__main__':
rev_dict[tokenizer.pad_id] = '' # replace [PAD]
rev_dict[tokenizer.unk_id] = '' # replace [UNK]
sd = P.load(args.save_dir)
sd = P.load(str(args.save_dir))
ernie.set_state_dict(sd)
def map_fn(src_ids):
......
......@@ -308,7 +308,7 @@ def seq2seq(model, tokenizer, args):
log.debug(msg)
if args.save_dir is not None and step % 1000 == 0 and env.dev_id == 0:
P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
P.save(model.state_dict(), str(args.save_dir / 'ckpt.bin'))
if args.predict_output_dir is not None and step > args.skip_eval_steps and step % args.eval_steps == 0:
assert args.predict_output_dir.exists(), \
......@@ -320,7 +320,7 @@ def seq2seq(model, tokenizer, args):
evaluate(model, dev_ds, step, args)
if args.save_dir is not None:
P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
P.save(model.state_dict(),str( args.save_dir / 'ckpt.bin'))
if __name__ == '__main__':
......@@ -414,7 +414,7 @@ if __name__ == '__main__':
if args.init_checkpoint is not None:
log.info('loading checkpoint from %s' % args.init_checkpoint)
sd = P.load(args.init_checkpoint)
sd = P.load(str(args.init_checkpoint))
ernie.set_state_dict(sd)
seq2seq(ernie, tokenizer, args)
......@@ -290,7 +290,7 @@ class PretrainedModel(object):
# log.debug('load pretrained weight from program state')
# F.io.load_program_state(param_path) #buggy in dygraph.gurad, push paddle to fix
if state_dict_path.exists():
m = P.load(state_dict_path)
m = P.load(str(state_dict_path))
for k, v in model.state_dict().items():
if k not in m:
log.warn('param:%s not set in pretrained model, skip' % k)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册