未验证 提交 730b87a9 编写于 作者: L Lyon 提交者: GitHub

Merge pull request #16 from Oneflow-Inc/dev_test_map

replace depreciated function with new
......@@ -66,7 +66,7 @@ args = parser.parse_args()
flow.config.load_library(oneflow_yolov3.lib_path())
func_config = flow.FunctionConfig()
func_config.default_distribute_strategy(flow.scope.consistent_view())
func_config.default_logical_view(flow.scope.consistent_view())
func_config.default_data_type(flow.float)
if args.use_tensorrt != 0:
func_config.use_tensorrt(True)
......
......@@ -4,6 +4,7 @@ import os
import numpy as np
import oneflow as flow
import oneflow.typing as tp
import utils
from data_preprocess import batch_image_preprocess_with_label
from tqdm import tqdm
......@@ -54,38 +55,22 @@ parser.add_argument(
default=0,
required=False)
parser.add_argument("-image_paths", "--image_paths", type=str, required=False)
args = parser.parse_args()
flow.config.load_library(oneflow_yolov3.lib_path())
func_config = flow.FunctionConfig()
func_config.default_distribute_strategy(flow.scope.consistent_view())
func_config.default_logical_view(flow.scope.consistent_view())
func_config.default_data_type(flow.float)
if args.use_tensorrt != 0:
func_config.use_tensorrt(True)
# func_config.tensorrt.use_fp16()
input_blob_def_dict = {
"images": flow.FixedTensorDef(
(args.batch_size,
3,
args.image_height,
args.image_width),
dtype=flow.float),
"origin_image_info": flow.FixedTensorDef(
(args.batch_size,
2),
dtype=flow.int32),
}
@flow.global_function(func_config)
def yolo_user_op_eval_job(
images=input_blob_def_dict["images"],
origin_image_info=input_blob_def_dict["origin_image_info"]):
def yolo_user_op_eval_job(images:tp.Numpy.Placeholder((args.batch_size, 3, args.image_height, args.image_width), dtype=flow.float),
origin_image_info:tp.Numpy.Placeholder((args.batch_size, 2), dtype=flow.int32)
):
yolo_pos_result, yolo_prob_result = YoloPredictNet(
images, origin_image_info, trainable=False)
return yolo_pos_result, yolo_prob_result, origin_image_info
......@@ -100,7 +85,7 @@ if __name__ == "__main__":
names = f.read().split('\n')
names = list(filter(None, names))
flow.config.gpu_device_num(args.gpu_num_per_node)
# load model
# Load model
check_point = flow.train.CheckPoint()
check_point.load(args.model_load_dir)
......@@ -110,6 +95,8 @@ if __name__ == "__main__":
path_list.append(line.strip('\n'))
iter_num = math.floor(len(path_list) / float(args.batch_size))
# evaluate mAP
"""
reference:
https://github.com/ultralytics/yolov3/blob/master/test.py
......
......@@ -97,7 +97,7 @@ args = parser.parse_args()
flow.config.gpu_device_num(args.gpu_num_per_node)
flow.config.load_library(oneflow_yolov3.lib_path())
func_config = flow.FunctionConfig()
func_config.default_distribute_strategy(flow.scope.consistent_view())
func_config.default_logical_view(flow.scope.consistent_view())
func_config.default_data_type(flow.float)
func_config.train.primary_lr(args.base_lr)
func_config.train.model_update_conf(dict(naive_conf={}))
......
......@@ -9,5 +9,5 @@ python3 oneflow_yolov3/model/yolo_train.py \
--num_epoch=130 --model_load_dir=$model_dir \
--classes=80 --num_boxes=90 --save_frequency=1 \
--model_save_dir="save_model" \
--dataset_dir="data/COCO/test_trainvalno5k.txt"
# --dataset_dir="data/trainvalno5k.txt"
--dataset_dir="data/test_trainvalno5k.txt"
# --dataset_dir="data/COCO/trainvalno5k.txt"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册