提交 f7329ddf 编写于 作者: C cuicheng01

support load pretrain from url

上级 c1f4e463
# global configs # global configs
Global: Global:
checkpoints: null checkpoints: null
# please download pretrained model via this link: pretrained_model: "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/pretrain/product_ResNet50_vd_Aliproduct_v1.0_pretrained.pdparams"
# https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/pretrain/product_ResNet50_vd_Aliproduct_v1.0_pretrained.pdparams
pretrained_model: product_ResNet50_vd_Aliproduct_v1.0_pretrained
output_dir: ./output/ output_dir: ./output/
device: gpu device: gpu
save_interval: 10 save_interval: 10
......
# global configs # global configs
Global: Global:
checkpoints: null checkpoints: null
# please download pretrained model via this link: pretrained_model: "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/pretrain/product_ResNet50_vd_Aliproduct_v1.0_pretrained.pdparams"
# https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/pretrain/product_ResNet50_vd_Aliproduct_v1.0_pretrained.pdparams
pretrained_model: product_ResNet50_vd_Aliproduct_v1.0_pretrained
output_dir: ./output/ output_dir: ./output/
device: gpu device: gpu
save_interval: 10 save_interval: 10
......
...@@ -39,7 +39,7 @@ from ppcls.arch import apply_to_static ...@@ -39,7 +39,7 @@ from ppcls.arch import apply_to_static
from ppcls.loss import build_loss from ppcls.loss import build_loss
from ppcls.metric import build_metrics from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer from ppcls.optimizer import build_optimizer
from ppcls.utils.save_load import load_dygraph_pretrain from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
from ppcls.utils.save_load import init_model from ppcls.utils.save_load import init_model
from ppcls.utils import save_load from ppcls.utils import save_load
...@@ -77,8 +77,12 @@ class Trainer(object): ...@@ -77,8 +77,12 @@ class Trainer(object):
apply_to_static(self.config, self.model) apply_to_static(self.config, self.model)
if self.config["Global"]["pretrained_model"] is not None: if self.config["Global"]["pretrained_model"] is not None:
load_dygraph_pretrain(self.model, if self.config["Global"]["pretrained_model"].startswith("http"):
self.config["Global"]["pretrained_model"]) load_dygraph_pretrain_from_url(
self.model, self.config["Global"]["pretrained_model"])
else:
load_dygraph_pretrain(
self.model, self.config["Global"]["pretrained_model"])
if self.config["Global"]["distributed"]: if self.config["Global"]["distributed"]:
self.model = paddle.DataParallel(self.model) self.model = paddle.DataParallel(self.model)
...@@ -362,6 +366,7 @@ class Trainer(object): ...@@ -362,6 +366,7 @@ class Trainer(object):
if self.is_rec: if self.is_rec:
out = self.model(batch[0], batch[1]) out = self.model(batch[0], batch[1])
else: else:
self.model.eval()
out = self.model(batch[0]) out = self.model(batch[0])
# calc loss # calc loss
if self.eval_loss_func is not None: if self.eval_loss_func is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册