提交 2bec96c1 编写于 作者: D dengkaipeng

remove model_cfg.

上级 b2389e38
...@@ -39,7 +39,7 @@ def eval(): ...@@ -39,7 +39,7 @@ def eval():
if not os.path.exists('output'): if not os.path.exists('output'):
os.mkdir('output') os.mkdir('output')
model = YOLOv3(cfg.model_cfg_path, is_train=False) model = YOLOv3(is_train=False)
model.build_model() model.build_model()
outputs = model.get_pred() outputs = model.get_pred()
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
......
...@@ -17,7 +17,7 @@ def infer(): ...@@ -17,7 +17,7 @@ def infer():
if not os.path.exists('output'): if not os.path.exists('output'):
os.mkdir('output') os.mkdir('output')
model = YOLOv3(cfg.model_cfg_path, is_train=False) model = YOLOv3(is_train=False)
model.build_model() model.build_model()
outputs = model.get_pred() outputs = model.get_pred()
input_size = cfg.input_size input_size = cfg.input_size
......
...@@ -56,11 +56,9 @@ def upsample(input, scale=2,name=None): ...@@ -56,11 +56,9 @@ def upsample(input, scale=2,name=None):
class YOLOv3(object): class YOLOv3(object):
def __init__(self, def __init__(self,
model_cfg_path,
is_train=True, is_train=True,
use_pyreader=True, use_pyreader=True,
use_random=True): use_random=True):
self.model_cfg_path = model_cfg_path
self.is_train = is_train self.is_train = is_train
self.use_pyreader = use_pyreader self.use_pyreader = use_pyreader
self.use_random = use_random self.use_random = use_random
......
...@@ -42,7 +42,7 @@ def train(): ...@@ -42,7 +42,7 @@ def train():
if not os.path.exists(cfg.model_save_dir): if not os.path.exists(cfg.model_save_dir):
os.makedirs(cfg.model_save_dir) os.makedirs(cfg.model_save_dir)
model = YOLOv3(cfg.model_cfg_path, use_pyreader=cfg.use_pyreader) model = YOLOv3(use_pyreader=cfg.use_pyreader)
model.build_model() model.build_model()
input_size = cfg.input_size input_size = cfg.input_size
loss = model.loss() loss = model.loss()
......
...@@ -96,7 +96,6 @@ def parse_args(): ...@@ -96,7 +96,6 @@ def parse_args():
# ENV # ENV
add_arg('parallel', bool, True, "Whether use parallel.") add_arg('parallel', bool, True, "Whether use parallel.")
add_arg('use_gpu', bool, True, "Whether use GPU.") add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('model_cfg_path', str, 'config/yolov3.cfg', "YOLO model config file path.")
add_arg('model_save_dir', str, 'checkpoints', "The path to save model.") add_arg('model_save_dir', str, 'checkpoints', "The path to save model.")
add_arg('pretrain', str, 'weights/darknet53', "The pretrain model path.") add_arg('pretrain', str, 'weights/darknet53', "The pretrain model path.")
add_arg('weights', str, 'weights/yolov3', "The weights path.") add_arg('weights', str, 'weights/yolov3', "The weights path.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册