提交 cd9819a6 编写于 作者: E Evgeny Izutov

Fixes

上级 584ae970
...@@ -15,14 +15,15 @@ If you want to make a contribution please follow [the guideline](CONTRIBUTING.md ...@@ -15,14 +15,15 @@ If you want to make a contribution please follow [the guideline](CONTRIBUTING.md
```Shell ```Shell
git clone https://github.com/opencv/training_toolbox_caffe.git caffe git clone https://github.com/opencv/training_toolbox_caffe.git caffe
cd caffe cd caffe
git checkout master git checkout develop
``` ```
2. Build the code. Please follow [Caffe instruction](http://caffe.berkeleyvision.org/installation.html) to install all necessary packages and build it. 2. Build the code. Please follow [Caffe instruction](http://caffe.berkeleyvision.org/installation.html) to install all necessary packages and build it.
```Shell ```Shell
sudo pip install -r $CAFFE_ROOT/python/requirements.txt
mkdir build && cd build mkdir build && cd build
cmake .. cmake ..
make -j8 make -j8
# Make sure to include $CAFFE_ROOT/python to your PYTHONPATH. export PYTHONPATH=$PYTHONPATH:$CAFFE_ROOT/python
# (Optional) # (Optional)
make runtest -j8 make runtest -j8
make pytest make pytest
......
...@@ -535,6 +535,9 @@ def calc_mr_ap(scores, true_positives, false_positives, num_gt, num_images, fppi ...@@ -535,6 +535,9 @@ def calc_mr_ap(scores, true_positives, false_positives, num_gt, num_images, fppi
return 0.5 * (miss_rates[left_position] + miss_rates[right_position]) return 0.5 * (miss_rates[left_position] + miss_rates[right_position])
if len(true_positives) == 0 or np.sum(true_positives) == 0:
return 1.0, 0.0
sorted_ind = np.argsort(-scores) sorted_ind = np.argsort(-scores)
fp_sorted = false_positives[sorted_ind] fp_sorted = false_positives[sorted_ind]
tp_sorted = true_positives[sorted_ind] tp_sorted = true_positives[sorted_ind]
...@@ -696,7 +699,7 @@ def main(): ...@@ -696,7 +699,7 @@ def main():
if len([True for b in bbox_list if b.label == class_id]) > 0] if len([True for b in bbox_list if b.label == class_id]) > 0]
class_pred_image_ids = [img_id for img_id, bbox_list in iteritems(predicted_actions) class_pred_image_ids = [img_id for img_id, bbox_list in iteritems(predicted_actions)
if len([True for b in bbox_list if b.label == class_id]) > 0] if len([True for b in bbox_list if b.label == class_id]) > 0]
class_num_images = np.sum(np.unique(class_gt_image_ids + class_pred_image_ids)) class_num_images = np.sum(len(np.unique(class_gt_image_ids + class_pred_image_ids)))
class_mr, class_ap = calc_mr_ap(class_scores, class_tp, class_fp, class_num_gt_bboxes, class_num_images) class_mr, class_ap = calc_mr_ap(class_scores, class_tp, class_fp, class_num_gt_bboxes, class_num_images)
print(' {}: AP: {:.2f} miss_rate@0.1: {:.2f}' print(' {}: AP: {:.2f} miss_rate@0.1: {:.2f}'
......
...@@ -134,6 +134,13 @@ class SampleDataFromDisk(object): ...@@ -134,6 +134,13 @@ class SampleDataFromDisk(object):
total_frames += 1 total_frames += 1
total_objects += len(gt_objects) total_objects += len(gt_objects)
for class_id in self._class_queues:
if class_id == ignore_class_id:
continue
if len(self._class_queues[class_id]) == 0:
raise Exception('Cannot find frames with {} action label'.format(class_id))
LOG('DataLayer stats: loaded {} frames with {} objects.'.format(total_frames, total_objects)) LOG('DataLayer stats: loaded {} frames with {} objects.'.format(total_frames, total_objects))
self._print_stat(glob_class_counts, self._class_queues, self._ignore_class_id) self._print_stat(glob_class_counts, self._class_queues, self._ignore_class_id)
......
...@@ -220,7 +220,7 @@ class DetMatcherLayer(BaseLayer): ...@@ -220,7 +220,7 @@ class DetMatcherLayer(BaseLayer):
overlaps = intersection_area / union_area overlaps = intersection_area / union_area
overlaps[np.less_equal(union_area, 0.0)] = 0.0 overlaps[np.less_equal(union_area, 0.0)] = 0.0
return 1.0 - intersection_area / union_area return 1.0 - overlaps
matched_detections = {} matched_detections = {}
for item_id in gt_data: for item_id in gt_data:
......
...@@ -10,7 +10,7 @@ networkx>=1.8.1 ...@@ -10,7 +10,7 @@ networkx>=1.8.1
nose>=1.3.0 nose>=1.3.0
pandas>=0.12.0 pandas>=0.12.0
python-dateutil>=2.6.0 python-dateutil>=2.6.0
protobuf>=2.5.0 protobuf==2.6.1
python-gflags>=2.0 python-gflags>=2.0
pyyaml>=3.10 pyyaml>=3.10
Pillow>=5.1.0 Pillow>=5.1.0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册