未验证 提交 de81b6fb 编写于 作者: X xiaoting 提交者: GitHub

fix load for yolo_dy (#4245)

* fix load for yolo_dy
上级 8afd79b2
......@@ -24,8 +24,8 @@ from paddle.fluid.regularizer import L2Decay
from config import cfg
from paddle.fluid.dygraph.nn import Conv2D, BatchNorm
from darknet import DarkNet53_conv_body
from darknet import ConvBNLayer
from .darknet import DarkNet53_conv_body
from .darknet import ConvBNLayer
from paddle.fluid.dygraph.base import to_variable
......
......@@ -80,16 +80,17 @@ def train():
if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
model = YOLOv3(3, is_train=True)
if args.use_data_parallel:
model = fluid.dygraph.parallel.DataParallel(model, strategy)
if cfg.pretrain:
restore, _ = fluid.load_dygraph(cfg.pretrain)
model.blocks.set_dict(restore)
model.block.set_dict(restore)
if cfg.finetune:
restore, _ = fluid.load_dygraph(cfg.finetune)
model.set_dict(restore)
model.set_dict(restore, use_structured_name=True)
if args.use_data_parallel:
model = fluid.dygraph.parallel.DataParallel(model, strategy)
boundaries = cfg.lr_steps
gamma = cfg.lr_gamma
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册