From f7329ddf74fcc97235226bc02edd585a62ad3bf6 Mon Sep 17 00:00:00 2001 From: cuicheng01 Date: Mon, 12 Jul 2021 07:45:44 +0000 Subject: [PATCH] support load pretrain from url --- ppcls/configs/Products/ResNet50_vd_Inshop.yaml | 4 +--- ppcls/configs/Products/ResNet50_vd_SOP.yaml | 4 +--- ppcls/engine/trainer.py | 11 ++++++++--- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/ppcls/configs/Products/ResNet50_vd_Inshop.yaml b/ppcls/configs/Products/ResNet50_vd_Inshop.yaml index 0a0e2725..b29a3a3f 100644 --- a/ppcls/configs/Products/ResNet50_vd_Inshop.yaml +++ b/ppcls/configs/Products/ResNet50_vd_Inshop.yaml @@ -1,9 +1,7 @@ # global configs Global: checkpoints: null -# please download pretrained model via this link: -# https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/pretrain/product_ResNet50_vd_Aliproduct_v1.0_pretrained.pdparams - pretrained_model: product_ResNet50_vd_Aliproduct_v1.0_pretrained + pretrained_model: "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/pretrain/product_ResNet50_vd_Aliproduct_v1.0_pretrained.pdparams" output_dir: ./output/ device: gpu save_interval: 10 diff --git a/ppcls/configs/Products/ResNet50_vd_SOP.yaml b/ppcls/configs/Products/ResNet50_vd_SOP.yaml index 795fb026..484b6ff8 100644 --- a/ppcls/configs/Products/ResNet50_vd_SOP.yaml +++ b/ppcls/configs/Products/ResNet50_vd_SOP.yaml @@ -1,9 +1,7 @@ # global configs Global: checkpoints: null -# please download pretrained model via this link: -# https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/pretrain/product_ResNet50_vd_Aliproduct_v1.0_pretrained.pdparams - pretrained_model: product_ResNet50_vd_Aliproduct_v1.0_pretrained + pretrained_model: "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/pretrain/product_ResNet50_vd_Aliproduct_v1.0_pretrained.pdparams" output_dir: ./output/ device: gpu save_interval: 10 diff --git a/ppcls/engine/trainer.py b/ppcls/engine/trainer.py index 9db7fbcf..bb42ca99 100644 --- a/ppcls/engine/trainer.py +++ b/ppcls/engine/trainer.py @@ -39,7 +39,7 @@ from ppcls.arch import apply_to_static from ppcls.loss import build_loss from ppcls.metric import build_metrics from ppcls.optimizer import build_optimizer -from ppcls.utils.save_load import load_dygraph_pretrain +from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url from ppcls.utils.save_load import init_model from ppcls.utils import save_load @@ -77,8 +77,12 @@ class Trainer(object): apply_to_static(self.config, self.model) if self.config["Global"]["pretrained_model"] is not None: - load_dygraph_pretrain(self.model, - self.config["Global"]["pretrained_model"]) + if self.config["Global"]["pretrained_model"].startswith("http"): + load_dygraph_pretrain_from_url( + self.model, self.config["Global"]["pretrained_model"]) + else: + load_dygraph_pretrain( + self.model, self.config["Global"]["pretrained_model"]) if self.config["Global"]["distributed"]: self.model = paddle.DataParallel(self.model) @@ -362,6 +366,7 @@ class Trainer(object): if self.is_rec: out = self.model(batch[0], batch[1]) else: + self.model.eval() out = self.model(batch[0]) # calc loss if self.eval_loss_func is not None: -- GitLab