From de81b6fb8c17eef1373c646e61310b18c9c6fa8a Mon Sep 17 00:00:00 2001 From: xiaoting <31891223+tink2123@users.noreply.github.com> Date: Wed, 12 Feb 2020 14:48:30 +0800 Subject: [PATCH] fix load for yolo_dy (#4245) * fix load for yolo_dy --- dygraph/yolov3/models/yolov3.py | 4 ++-- dygraph/yolov3/train.py | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/dygraph/yolov3/models/yolov3.py b/dygraph/yolov3/models/yolov3.py index dee29926..b49c9f63 100755 --- a/dygraph/yolov3/models/yolov3.py +++ b/dygraph/yolov3/models/yolov3.py @@ -24,8 +24,8 @@ from paddle.fluid.regularizer import L2Decay from config import cfg from paddle.fluid.dygraph.nn import Conv2D, BatchNorm -from darknet import DarkNet53_conv_body -from darknet import ConvBNLayer +from .darknet import DarkNet53_conv_body +from .darknet import ConvBNLayer from paddle.fluid.dygraph.base import to_variable diff --git a/dygraph/yolov3/train.py b/dygraph/yolov3/train.py index b7e2b6ab..7c1548e6 100755 --- a/dygraph/yolov3/train.py +++ b/dygraph/yolov3/train.py @@ -80,16 +80,17 @@ def train(): if args.use_data_parallel: strategy = fluid.dygraph.parallel.prepare_context() model = YOLOv3(3, is_train=True) - if args.use_data_parallel: - model = fluid.dygraph.parallel.DataParallel(model, strategy) if cfg.pretrain: restore, _ = fluid.load_dygraph(cfg.pretrain) - model.blocks.set_dict(restore) + model.block.set_dict(restore) if cfg.finetune: restore, _ = fluid.load_dygraph(cfg.finetune) - model.set_dict(restore) + model.set_dict(restore, use_structured_name=True) + + if args.use_data_parallel: + model = fluid.dygraph.parallel.DataParallel(model, strategy) boundaries = cfg.lr_steps gamma = cfg.lr_gamma -- GitLab