提交 c8bb1af5 编写于 作者: P Phil Ferriere 提交者: Waleed Abdulla

Automatically download trained model file

上级 2aa62a39
......@@ -45,10 +45,11 @@
"# Directory to save logs and trained model\n",
"MODEL_DIR = os.path.join(ROOT_DIR, \"logs\")\n",
"\n",
"# Path to trained weights file\n",
"# Download this file and place in the root of your \n",
"# project (See README file for details)\n",
"# Local path to trained weights file\n",
"COCO_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_coco.h5\")\n",
"# Download COCO trained weights from Releases if needed\n",
"if not os.path.exists(COCO_MODEL_PATH):\n",
" utils.download_trained_weights(COCO_MODEL_PATH)\n",
"\n",
"# Directory of images to run detection on\n",
"IMAGE_DIR = os.path.join(ROOT_DIR, \"images\")"
......@@ -144,6 +145,7 @@
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
......@@ -282,7 +284,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
"version": "3.6.2"
}
},
"nbformat": 4,
......@@ -49,10 +49,11 @@
"# Directory to save logs and trained model\n",
"MODEL_DIR = os.path.join(ROOT_DIR, \"logs\")\n",
"\n",
"# Path to trained weights file\n",
"# Download this file and place in the root of your \n",
"# project (See README file for details)\n",
"# Local path to trained weights file\n",
"COCO_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_coco.h5\")\n",
"# Download COCO trained weights from Releases if needed\n",
"if not os.path.exists(COCO_MODEL_PATH):\n",
" utils.download_trained_weights(COCO_MODEL_PATH)\n",
"\n",
"# Path to Shapes trained weights\n",
"SHAPES_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_shapes.h5\")"
......@@ -1377,7 +1378,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
"version": "3.6.2"
}
},
"nbformat": 4,
......@@ -44,8 +44,11 @@
"# Directory to save logs and trained model\n",
"MODEL_DIR = os.path.join(ROOT_DIR, \"logs\")\n",
"\n",
"# Path to COCO trained weights\n",
"# Local path to trained weights file\n",
"COCO_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_coco.h5\")\n",
"# Download COCO trained weights from Releases if needed\n",
"if not os.path.exists(COCO_MODEL_PATH):\n",
" utils.download_trained_weights(COCO_MODEL_PATH)\n",
"\n",
"# Path to Shapes trained weights\n",
"SHAPES_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_shapes.h5\")"
......@@ -266,7 +269,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
"version": "3.6.2"
}
},
"nbformat": 4,
......@@ -51,8 +51,11 @@
"# Directory to save logs and trained model\n",
"MODEL_DIR = os.path.join(ROOT_DIR, \"logs\")\n",
"\n",
"# Path to COCO trained weights\n",
"COCO_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_coco.h5\")"
"# Local path to trained weights file\n",
"COCO_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_coco.h5\")\n",
"# Download COCO trained weights from Releases if needed\n",
"if not os.path.exists(COCO_MODEL_PATH):\n",
" utils.download_trained_weights(COCO_MODEL_PATH)"
]
},
{
......@@ -1024,7 +1027,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
"version": "3.6.2"
}
},
"nbformat": 4,
......
......@@ -16,6 +16,11 @@ import tensorflow as tf
import scipy.misc
import skimage.color
import skimage.io
import urllib.request
import shutil
# URL from which to download the latest COCO trained weights
COCO_MODEL_URL = "https://github.com/matterport/Mask_RCNN/releases/download/v2.0/mask_rcnn_coco.h5"
############################################################
......@@ -688,3 +693,16 @@ def batch_slice(inputs, graph_fn, batch_size, names=None):
result = result[0]
return result
def download_trained_weights(coco_model_path, verbose=1):
"""Download COCO trained weights from Releases.
coco_model_path: local path of COCO trained weights
"""
if verbose > 0:
print("Downloading pretrained model to " + coco_model_path + " ...")
with urllib.request.urlopen(COCO_MODEL_URL) as resp, open(coco_model_path, 'wb') as out:
shutil.copyfileobj(resp, out)
if verbose > 0:
print("... done downloading pretrained model!")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册