提交 0f579fe0 编写于 作者: W Waleed Abdulla

Cleanup: split GT boxes and class IDs tensors.

上级 459dbb74
......@@ -438,11 +438,12 @@
],
"source": [
"image_id = np.random.choice(dataset.image_ids, 1)[0]\n",
"image, image_meta, bbox, mask = modellib.load_image_gt(\n",
"image, image_meta, class_ids, bbox, mask = modellib.load_image_gt(\n",
" dataset, config, image_id, use_mini_mask=False)\n",
"\n",
"log(\"image\", image)\n",
"log(\"image_meta\", image_meta)\n",
"log(\"class_ids\", class_ids)\n",
"log(\"bbox\", bbox)\n",
"log(\"mask\", mask)\n",
"\n",
......@@ -466,7 +467,7 @@
}
],
"source": [
"visualize.display_instances(image, bbox[:,:4], mask, bbox[:,4], dataset.class_names)"
"visualize.display_instances(image, bbox, mask, class_ids, dataset.class_names)"
]
},
{
......@@ -496,7 +497,7 @@
],
"source": [
"# Add augmentation and mask resizing.\n",
"image, image_meta, bbox, mask = modellib.load_image_gt(\n",
"image, image_meta, class_ids, bbox, mask = modellib.load_image_gt(\n",
" dataset, config, image_id, augment=True, use_mini_mask=True)\n",
"log(\"mask\", mask)\n",
"display_images([image]+[mask[:,:,i] for i in range(min(mask.shape[-1], 7))])"
......@@ -520,7 +521,7 @@
],
"source": [
"mask = utils.expand_mask(bbox, mask, image.shape)\n",
"visualize.display_instances(image, bbox[:,:4], mask, bbox[:,4], dataset.class_names)"
"visualize.display_instances(image, bbox, mask, class_ids, dataset.class_names)"
]
},
{
......@@ -634,7 +635,7 @@
"\n",
"# Load and draw random image\n",
"image_id = np.random.choice(dataset.image_ids, 1)[0]\n",
"image, image_meta, _, _ = modellib.load_image_gt(dataset, config, image_id)\n",
"image, image_meta, _, _, _ = modellib.load_image_gt(dataset, config, image_id)\n",
"fig, ax = plt.subplots(1, figsize=(10, 10))\n",
"ax.imshow(image)\n",
"levels = len(config.BACKBONE_SHAPES)\n",
......@@ -313,21 +313,17 @@
],
"source": [
"image_id = random.choice(dataset.image_ids)\n",
"image, image_meta, gt_bbox, gt_mask =\\\n",
"image, image_meta, gt_class_id, gt_bbox, gt_mask =\\\n",
" modellib.load_image_gt(dataset, config, image_id, use_mini_mask=False)\n",
"info = dataset.image_info[image_id]\n",
"print(\"image ID: {}.{} ({}) {}\".format(info[\"source\"], info[\"id\"], image_id, \n",
" dataset.image_reference(image_id)))\n",
"gt_class_id = gt_bbox[:, 4]\n",
"\n",
"# Run object detection\n",
"results = model.detect([image], verbose=1)\n",
"\n",
"# Display results\n",
"ax = get_ax(1)\n",
"r = results[0]\n",
"# visualize.display_instances(image, gt_bbox[:,:4], gt_mask, gt_bbox[:,4], \n",
"# dataset.class_names, ax=ax[0], title=\"Ground Truth\")\n",
"visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'], \n",
" dataset.class_names, r['scores'], ax=ax,\n",
" title=\"Predictions\")\n",
......@@ -361,7 +357,7 @@
],
"source": [
"# Draw precision-recall curve\n",
"AP, precisions, recalls, overlaps = utils.compute_ap(gt_bbox[:,:4], gt_bbox[:,4], \n",
"AP, precisions, recalls, overlaps = utils.compute_ap(gt_bbox, gt_class_id, \n",
" r['rois'], r['class_ids'], r['scores'])\n",
"visualize.plot_precision_recall(AP, precisions, recalls)"
]
......@@ -384,7 +380,7 @@
],
"source": [
"# Grid of ground truth objects and their predictions\n",
"visualize.plot_overlaps(gt_bbox[:, 4], r['class_ids'], r['scores'],\n",
"visualize.plot_overlaps(gt_class_id, r['class_ids'], r['scores'],\n",
" overlaps, dataset.class_names)"
]
},
......@@ -422,7 +418,7 @@
" APs = []\n",
" for image_id in image_ids:\n",
" # Load image\n",
" image, image_meta, gt_bbox, gt_mask =\\\n",
" image, image_meta, gt_class_id, gt_bbox, gt_mask =\\\n",
" modellib.load_image_gt(dataset, config,\n",
" image_id, use_mini_mask=False)\n",
" # Run object detection\n",
......@@ -430,7 +426,7 @@
" # Compute AP\n",
" r = results[0]\n",
" AP, precisions, recalls, overlaps =\\\n",
" utils.compute_ap(gt_bbox[:,:4], gt_bbox[:,4],\n",
" utils.compute_ap(gt_bbox, gt_class_id,\n",
" r['rois'], r['class_ids'], r['scores'])\n",
" APs.append(AP)\n",
" return APs\n",
......@@ -493,7 +489,7 @@
"# target_rpn_match is 1 for positive anchors, -1 for negative anchors\n",
"# and 0 for neutral anchors.\n",
"target_rpn_match, target_rpn_bbox = modellib.build_rpn_targets(\n",
" image.shape, model.anchors, gt_bbox, model.config)\n",
" image.shape, model.anchors, gt_class_id, gt_bbox, model.config)\n",
"log(\"target_rpn_match\", target_rpn_match)\n",
"log(\"target_rpn_bbox\", target_rpn_bbox)\n",
"\n",
......@@ -568,8 +564,9 @@
"pillar = model.keras_model.get_layer(\"ROI\").output # node to start searching from\n",
"\n",
"# TF 1.4 introduces a new version of NMS. Search for both names to support TF 1.3 and 1.4\n",
"nms_node = (model.ancestor(pillar, \"ROI/rpn_non_max_suppression:0\")\n",
" or model.ancestor(pillar, \"ROI/rpn_non_max_suppression/NonMaxSuppressionV2:0\"))\n",
"nms_node = model.ancestor(pillar, \"ROI/rpn_non_max_suppression:0\")\n",
"if nms_node is None:\n",
" nms_node = model.ancestor(pillar, \"ROI/rpn_non_max_suppression/NonMaxSuppressionV2:0\")\n",
"\n",
"rpn = model.run_graph([image], [\n",
" (\"rpn_class\", model.keras_model.get_layer(\"rpn_class\").output),\n",
此差异已折叠。
......@@ -907,16 +907,17 @@
"source": [
"# Test on a random image\n",
"image_id = random.choice(dataset_val.image_ids)\n",
"original_image, image_meta, gt_bbox, gt_mask =\\\n",
"original_image, image_meta, gt_class_id, gt_bbox, gt_mask =\\\n",
" modellib.load_image_gt(dataset_val, inference_config, \n",
" image_id, use_mini_mask=False)\n",
"\n",
"log(\"original_image\", original_image)\n",
"log(\"image_meta\", image_meta)\n",
"log(\"gt_class_id\", gt_bbox)\n",
"log(\"gt_bbox\", gt_bbox)\n",
"log(\"gt_mask\", gt_mask)\n",
"\n",
"visualize.display_instances(original_image, gt_bbox[:,:4], gt_mask, gt_bbox[:,4], \n",
"visualize.display_instances(original_image, gt_bbox, gt_mask, gt_class_id, \n",
" dataset_train.class_names, figsize=(8, 8))"
]
},
......@@ -981,7 +982,7 @@
"APs = []\n",
"for image_id in image_ids:\n",
" # Load image and ground truth data\n",
" image, image_meta, gt_bbox, gt_mask =\\\n",
" image, image_meta, gt_class_id, gt_bbox, gt_mask =\\\n",
" modellib.load_image_gt(dataset_val, inference_config,\n",
" image_id, use_mini_mask=False)\n",
" molded_images = np.expand_dims(modellib.mold_image(image, inference_config), 0)\n",
......@@ -990,7 +991,7 @@
" r = results[0]\n",
" # Compute AP\n",
" AP, precisions, recalls, overlaps =\\\n",
" utils.compute_ap(gt_bbox[:,:4], gt_bbox[:,4],\n",
" utils.compute_ap(gt_bbox, gt_class_id,\n",
" r[\"rois\"], r[\"class_ids\"], r[\"scores\"])\n",
" APs.append(AP)\n",
" \n",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册