/*******************************************************************************
 * Copyright (C) 2018-2020 Intel Corporation
 *
 * SPDX-License-Identifier: MIT
 ******************************************************************************/

#include "tracker.h"
#include "mapped_mat.h"

#include "gva_utils.h"
#include "utils.h"
#include "video_frame.h"

#include <algorithm>
#include <cctype>
#include <cstddef>
#include <exception>
#include <functional>
#include <limits>
#include <map>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

using namespace VasWrapper;

namespace {

// Passed to the VAS tracker builder as max_num_objects.
// NOTE(review): presumably -1 means "no limit" — confirm against the VAS documentation.
constexpr int DEFAULT_MAX_NUM_OBJECTS = -1;

// Value of a tracked object's association_idx when the tracker did not associate it with any
// detection given for the current frame (such objects are appended as new regions).
constexpr int NO_ASSOCIATION = -1;

/**
 * Compares two characters ignoring case (per std::toupper in the current C locale).
 *
 * @return true if both characters upper-case to the same value.
 */
inline bool CaseInsCharCompareN(char a, char b) {
    // std::toupper is undefined for negative arguments other than EOF, and plain `char` is signed
    // on most platforms, so bytes >= 0x80 must be widened through unsigned char first.
    return std::toupper(static_cast<unsigned char>(a)) == std::toupper(static_cast<unsigned char>(b));
}

/**
 * Case-insensitive string equality built on CaseInsCharCompareN.
 *
 * @return true if s1 and s2 have the same length and match character by character ignoring case.
 */
inline bool CaseInsCompare(const std::string &s1, const std::string &s2) {
    // Different lengths can never match. Checking first also keeps the three-iterator
    // std::equal below from reading past the end of s2.
    if (s1.size() != s2.size())
        return false;
    return std::equal(s1.cbegin(), s1.cend(), s2.cbegin(), CaseInsCharCompareN);
}

/**
 * Parses a tracking type name (matched case-insensitively) into the VAS enum.
 *
 * Accepted names: ZERO_TERM, SHORT_TERM, ZERO_TERM_IMAGELESS, SHORT_TERM_IMAGELESS.
 *
 * @throws std::invalid_argument if the name is not one of the above.
 */
vas::ot::TrackingType trackingType(const std::string &tracking_type) {
    if (CaseInsCompare(tracking_type, "ZERO_TERM"))
        return vas::ot::TrackingType::ZERO_TERM;
    if (CaseInsCompare(tracking_type, "SHORT_TERM"))
        return vas::ot::TrackingType::SHORT_TERM;
    if (CaseInsCompare(tracking_type, "ZERO_TERM_IMAGELESS"))
        return vas::ot::TrackingType::ZERO_TERM_IMAGELESS;
    if (CaseInsCompare(tracking_type, "SHORT_TERM_IMAGELESS"))
        return vas::ot::TrackingType::SHORT_TERM_IMAGELESS;
    throw std::invalid_argument("Unknown tracking name " + tracking_type);
}

/**
 * Parses a device/backend name (matched case-insensitively) into the VAS backend enum.
 *
 * Accepted names: CPU, VPU, GPU, FPGA, HDDL.
 *
 * @throws std::invalid_argument if the name is not one of the above.
 */
vas::BackendType backendType(const std::string &backend_type) {
    if (CaseInsCompare(backend_type, "CPU"))
        return vas::BackendType::CPU;
    if (CaseInsCompare(backend_type, "VPU"))
        return vas::BackendType::VPU;
    if (CaseInsCompare(backend_type, "GPU"))
        return vas::BackendType::GPU;
    if (CaseInsCompare(backend_type, "FPGA"))
        return vas::BackendType::FPGA;
    if (CaseInsCompare(backend_type, "HDDL"))
        return vas::BackendType::HDDL;
    throw std::invalid_argument("Unknown tracking device " + backend_type);
}

/**
 * Maps a GStreamer raw video format to the VAS color format the tracker is built for.
 *
 * BGRA is reported to VAS as BGRX, the same as BGRx.
 * NOTE(review): presumably the alpha channel is treated as padding — confirm with VAS.
 * Any format without an explicit case falls back to BGR.
 * NOTE(review): that fallback silently mislabels other formats (e.g. RGB); consider rejecting them.
 */
vas::ColorFormat ConvertFormat(GstVideoFormat format) {
    switch (format) {
    case GST_VIDEO_FORMAT_BGR:
        return vas::ColorFormat::BGR;
    case GST_VIDEO_FORMAT_BGRx:
    case GST_VIDEO_FORMAT_BGRA:
        return vas::ColorFormat::BGRX;
    case GST_VIDEO_FORMAT_NV12:
        return vas::ColorFormat::NV12;
    case GST_VIDEO_FORMAT_I420:
        return vas::ColorFormat::I420;
    default:
        return vas::ColorFormat::BGR;
    }
}

/**
 * Converts the frame's regions of interest into VAS tracker input, in regions() order.
 *
 * Each region's class id is read from the "label_id" field of roi.detection(). Regions without
 * that field get std::numeric_limits<int>::max(). The first label string seen for each class id is
 * cached in `labels`, so later calls can name objects the tracker reports without a region.
 *
 * @param video_frame frame whose regions are read.
 * @param labels      class-id -> label cache; new ids are inserted, existing entries are kept.
 * @return one DetectedObject per region, in the same order as video_frame.regions().
 */
std::vector<vas::ot::DetectedObject> extractDetectedObjects(GVA::VideoFrame &video_frame,
                                                            std::unordered_map<int, std::string> &labels) {
    std::vector<GVA::RegionOfInterest> regions = video_frame.regions();

    std::vector<vas::ot::DetectedObject> detected_objects;
    detected_objects.reserve(regions.size());

    for (GVA::RegionOfInterest &roi : regions) {
        const int label_id = roi.detection().get_int("label_id", std::numeric_limits<int>::max());
        // Only query roi.label() for ids not cached yet; existing entries are never overwritten.
        if (labels.find(label_id) == labels.end())
            labels.emplace(label_id, roi.label());

        const auto rect = roi.rect();
        detected_objects.emplace_back(cv::Rect(rect.x, rect.y, rect.w, rect.h), label_id);
    }
    return detected_objects;
}

/**
 * Adds a new region of interest to `video_frame` for a tracked object that has no associated region.
 *
 * The region gets:
 *   - the tracker's rectangle;
 *   - `label` and a fixed confidence of 1.0;
 *   - the object's class label stored as "label_id" in roi.detection();
 *   - the tracker-assigned object id.
 */
void append(GVA::VideoFrame &video_frame, const vas::ot::Object &tracked_object, const std::string &label) {
    const auto &rect = tracked_object.rect;
    auto roi = video_frame.add_region(rect.x, rect.y, rect.width, rect.height, label, 1.0);
    roi.detection().set_int("label_id", tracked_object.class_label);
    roi.set_object_id(tracked_object.tracking_id);
}

} // namespace

S
Smertin, Dmitry 已提交
102 103
Tracker::Tracker(const GstGvaTrack *gva_track, const std::string &tracking_type) : gva_track(gva_track) {
    if (tracking_type.empty() || gva_track == nullptr) {
U
Umedzhon Abdumuminov 已提交
104 105 106
        throw std::invalid_argument("Tracker::Tracker: nullptr arguments is not allowed");
    }

S
Smertin, Dmitry 已提交
107 108 109 110 111 112 113 114 115 116 117 118
    vas::ot::ObjectTracker::Builder builder;
    builder.input_image_format = ConvertFormat(gva_track->info->finfo->format);
    builder.max_num_objects = DEFAULT_MAX_NUM_OBJECTS;

    // examples: VPU.1, CPU, VPU, etc.
    std::vector<std::string> full_device = Utils::splitString(gva_track->device, '.');
    builder.backend_type = backendType(full_device[0]);
    if (builder.backend_type == vas::BackendType::VPU and full_device.size() > 1) {
        std::map<std::string, std::string> config;
        config["device_id"] = full_device[1];
        builder.platform_config = config;
    }
U
Umedzhon Abdumuminov 已提交
119

S
Smertin, Dmitry 已提交
120
    object_tracker = builder.Build(trackingType(tracking_type));
U
Umedzhon Abdumuminov 已提交
121 122 123 124 125 126
}

/**
 * Runs the VAS tracker on one frame and writes the results back into the frame's regions.
 *
 * Tracked objects with status LOST are ignored. Of the rest:
 *   - objects associated with an existing region give that region their tracking id;
 *   - unassociated objects are appended as new regions, labelled from the `labels` cache.
 *
 * @param buffer frame to process; must not be null. It is mapped for reading only.
 * @throws std::invalid_argument if buffer is null.
 * @throws std::runtime_error if anything fails while tracking. The original exception is attached
 *         via std::throw_with_nested.
 */
void Tracker::track(GstBuffer *buffer) {
    if (buffer == nullptr)
        throw std::invalid_argument("buffer is nullptr");
    try {
        GVA::VideoFrame video_frame(buffer, gva_track->info);
        std::vector<vas::ot::DetectedObject> detected_objects = extractDetectedObjects(video_frame, labels);

        MappedMat cv_mat(buffer, gva_track->info, GST_MAP_READ);
        // detected_objects was built from regions() in the same order. association_idx (presumably an index
        // into detected_objects) therefore addresses the same element of this vector.
        std::vector<GVA::RegionOfInterest> regions = video_frame.regions();
        const auto tracked_objects = object_tracker->Track(cv_mat.mat(), detected_objects);
        for (const auto &tracked_object : tracked_objects) {
            if (tracked_object.status == vas::ot::TrackingStatus::LOST)
                continue;
            if (tracked_object.association_idx != NO_ASSOCIATION) {
                // Check the index rather than trusting it: indexing out of range is undefined behaviour.
                // Negative values wrap to huge sizes here, so they are rejected too.
                const auto region_idx = static_cast<std::size_t>(tracked_object.association_idx);
                if (region_idx >= regions.size())
                    throw std::out_of_range("Track: association index is out of range of frame regions");
                regions[region_idx].set_object_id(tracked_object.tracking_id);
            } else {
                auto it = labels.find(tracked_object.class_label);
                std::string label = it != labels.end() ? it->second : std::string();
                append(video_frame, tracked_object, label);
            }
        }
    } catch (const std::exception &e) {
        GST_ERROR("Exception within tracker occured: %s", e.what());
        // Keep the original failure attached so callers can unwind the full cause chain.
        // The thrown object still derives from std::runtime_error.
        std::throw_with_nested(std::runtime_error("Track: error while tracking objects"));
    }
}

S
Smertin, Dmitry 已提交
150 151 152
// TODO: use one typed function instead of four. Second arg is GVA-level tracking type enum (get rid of strings)
ITracker *Tracker::CreateShortTerm(const GstGvaTrack *gva_track) {
    return new Tracker(gva_track, "SHORT_TERM");
U
Umedzhon Abdumuminov 已提交
153 154
}

S
Smertin, Dmitry 已提交
155 156
ITracker *Tracker::CreateZeroTerm(const GstGvaTrack *gva_track) {
    return new Tracker(gva_track, "ZERO_TERM");
U
Umedzhon Abdumuminov 已提交
157
}
S
Smertin, Dmitry 已提交
158 159 160 161 162 163 164

ITracker *Tracker::CreateShortTermImageless(const GstGvaTrack *gva_track) {
    return new Tracker(gva_track, "SHORT_TERM_IMAGELESS");
}

/**
 * Creates a ZERO_TERM_IMAGELESS tracker for `gva_track`.
 *
 * @return a tracker allocated with `new`; the caller takes ownership.
 * @throws std::invalid_argument propagated from the Tracker constructor on invalid input.
 */
ITracker *Tracker::CreateZeroTermImageless(const GstGvaTrack *gva_track) {
    return new Tracker(gva_track, "ZERO_TERM_IMAGELESS");
}