未验证 提交 4091aca7 编写于 作者: W Walter 提交者: GitHub

Merge pull request #1314 from RainFrost1/release/2.3

[cherry pick]fix slim load pretrained model bug
......@@ -183,6 +183,11 @@ class Engine(object):
self.model = build_model(self.config["Arch"])
# set @to_static for benchmark, skip this by default.
apply_to_static(self.config, self.model)
# for slim
self.pruner = get_pruner(self.config, self.model)
self.quanter = get_quaner(self.config, self.model)
# load_pretrain
if self.config["Global"]["pretrained_model"] is not None:
if self.config["Global"]["pretrained_model"].startswith("http"):
......@@ -192,10 +197,6 @@ class Engine(object):
load_dygraph_pretrain(
self.model, self.config["Global"]["pretrained_model"])
# for slim
self.pruner = get_pruner(self.config, self.model)
self.quanter = get_quaner(self.config, self.model)
# build optimizer
if self.mode == 'train':
self.optimizer, self.lr_sch = build_optimizer(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册