提交 a1ec4ea3 编写于 作者: A Alexander Alekhin

Merge pull request #21361 from alalek:sample_fix_tracking

#!/usr/bin/env python #!/usr/bin/env python
''' '''
Tracker demo Tracker demo
...@@ -36,43 +35,49 @@ class App(object): ...@@ -36,43 +35,49 @@ class App(object):
def __init__(self, args): def __init__(self, args):
self.args = args self.args = args
self.trackerAlgorithm = args.tracker_algo
def initializeTracker(self, image, trackerAlgorithm): self.tracker = self.createTracker()
def createTracker(self):
if self.trackerAlgorithm == 'mil':
tracker = cv.TrackerMIL_create()
elif self.trackerAlgorithm == 'goturn':
params = cv.TrackerGOTURN_Params()
params.modelTxt = self.args.goturn
params.modelBin = self.args.goturn_model
tracker = cv.TrackerGOTURN_create(params)
elif self.trackerAlgorithm == 'dasiamrpn':
params = cv.TrackerDaSiamRPN_Params()
params.model = self.args.dasiamrpn_net
params.kernel_cls1 = self.args.dasiamrpn_kernel_cls1
params.kernel_r1 = self.args.dasiamrpn_kernel_r1
tracker = cv.TrackerDaSiamRPN_create(params)
else:
sys.exit("Tracker {} is not recognized. Please use one of three available: mil, goturn, dasiamrpn.".format(self.trackerAlgorithm))
return tracker
def initializeTracker(self, image):
while True: while True:
if trackerAlgorithm == 'mil':
tracker = cv.TrackerMIL_create()
elif trackerAlgorithm == 'goturn':
params = cv.TrackerGOTURN_Params()
params.modelTxt = self.args.goturn
params.modelBin = self.args.goturn_model
tracker = cv.TrackerGOTURN_create(params)
elif trackerAlgorithm == 'dasiamrpn':
params = cv.TrackerDaSiamRPN_Params()
params.model = self.args.dasiamrpn_net
params.kernel_cls1 = self.args.dasiamrpn_kernel_cls1
params.kernel_r1 = self.args.dasiamrpn_kernel_r1
tracker = cv.TrackerDaSiamRPN_create(params)
else:
sys.exit("Tracker {} is not recognized. Please use one of three available: mil, goturn, dasiamrpn.".format(trackerAlgorithm))
print('==> Select object ROI for tracker ...') print('==> Select object ROI for tracker ...')
bbox = cv.selectROI('tracking', image) bbox = cv.selectROI('tracking', image)
print('ROI: {}'.format(bbox)) print('ROI: {}'.format(bbox))
if bbox[2] <= 0 or bbox[3] <= 0:
sys.exit("ROI selection cancelled. Exiting...")
try: try:
tracker.init(image, bbox) self.tracker.init(image, bbox)
except Exception as e: except Exception as e:
print('Unable to initialize tracker with requested bounding box. Is there any object?') print('Unable to initialize tracker with requested bounding box. Is there any object?')
print(e) print(e)
print('Try again ...') print('Try again ...')
continue continue
return tracker return
def run(self): def run(self):
videoPath = self.args.input videoPath = self.args.input
trackerAlgorithm = self.args.tracker_algo print('Using video: {}'.format(videoPath))
camera = create_capture(videoPath, presets['cube']) camera = create_capture(cv.samples.findFileOrKeep(videoPath), presets['cube'])
if not camera.isOpened(): if not camera.isOpened():
sys.exit("Can't open video stream: {}".format(videoPath)) sys.exit("Can't open video stream: {}".format(videoPath))
...@@ -82,7 +87,7 @@ class App(object): ...@@ -82,7 +87,7 @@ class App(object):
assert image is not None assert image is not None
cv.namedWindow('tracking') cv.namedWindow('tracking')
tracker = self.initializeTracker(image, trackerAlgorithm) self.initializeTracker(image)
print("==> Tracking is started. Press 'SPACE' to re-initialize tracker or 'ESC' for exit...") print("==> Tracking is started. Press 'SPACE' to re-initialize tracker or 'ESC' for exit...")
...@@ -92,7 +97,7 @@ class App(object): ...@@ -92,7 +97,7 @@ class App(object):
print("Can't read frame") print("Can't read frame")
break break
ok, newbox = tracker.update(image) ok, newbox = self.tracker.update(image)
#print(ok, newbox) #print(ok, newbox)
if ok: if ok:
...@@ -101,7 +106,7 @@ class App(object): ...@@ -101,7 +106,7 @@ class App(object):
cv.imshow("tracking", image) cv.imshow("tracking", image)
k = cv.waitKey(1) k = cv.waitKey(1)
if k == 32: # SPACE if k == 32: # SPACE
tracker = self.initializeTracker(image) self.initializeTracker(image)
if k == 27: # ESC if k == 27: # ESC
break break
...@@ -112,22 +117,13 @@ if __name__ == '__main__': ...@@ -112,22 +117,13 @@ if __name__ == '__main__':
print(__doc__) print(__doc__)
parser = argparse.ArgumentParser(description="Run tracker") parser = argparse.ArgumentParser(description="Run tracker")
parser.add_argument("--input", type=str, default="vtest.avi", help="Path to video source") parser.add_argument("--input", type=str, default="vtest.avi", help="Path to video source")
parser.add_argument("--tracker_algo", type=str, default="mil", help="One of three available tracking algorithms: mil, goturn, dasiamrpn") parser.add_argument("--tracker_algo", type=str, default="mil", help="One of available tracking algorithms: mil, goturn, dasiamrpn")
parser.add_argument("--goturn", type=str, default="goturn.prototxt", help="Path to GOTURN architecture") parser.add_argument("--goturn", type=str, default="goturn.prototxt", help="Path to GOTURN architecture")
parser.add_argument("--goturn_model", type=str, default="goturn.caffemodel", help="Path to GOTERN model") parser.add_argument("--goturn_model", type=str, default="goturn.caffemodel", help="Path to GOTERN model")
parser.add_argument("--dasiamrpn_net", type=str, default="dasiamrpn_model.onnx", help="Path to onnx model of DaSiamRPN net") parser.add_argument("--dasiamrpn_net", type=str, default="dasiamrpn_model.onnx", help="Path to onnx model of DaSiamRPN net")
parser.add_argument("--dasiamrpn_kernel_r1", type=str, default="dasiamrpn_kernel_r1.onnx", help="Path to onnx model of DaSiamRPN kernel_r1") parser.add_argument("--dasiamrpn_kernel_r1", type=str, default="dasiamrpn_kernel_r1.onnx", help="Path to onnx model of DaSiamRPN kernel_r1")
parser.add_argument("--dasiamrpn_kernel_cls1", type=str, default="dasiamrpn_kernel_cls1.onnx", help="Path to onnx model of DaSiamRPN kernel_cls1") parser.add_argument("--dasiamrpn_kernel_cls1", type=str, default="dasiamrpn_kernel_cls1.onnx", help="Path to onnx model of DaSiamRPN kernel_cls1")
parser.add_argument("--dasiamrpn_backend", type=int, default=0, help="Choose one of computation backends:\
0: automatically (by default),\
1: Halide language (http://halide-lang.org/),\
2: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit),\
3: OpenCV implementation")
parser.add_argument("--dasiamrpn_target", type=int, default=0, help="Choose one of target computation devices:\
0: CPU target (by default),\
1: OpenCL,\
2: OpenCL fp16 (half-float precision),\
3: VPU")
args = parser.parse_args() args = parser.parse_args()
App(args).run() App(args).run()
cv.destroyAllWindows() cv.destroyAllWindows()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册