提交 c8bb1af5 编写于 作者: P Phil Ferriere 提交者: Waleed Abdulla

Automatically download trained model file

上级 2aa62a39
...@@ -45,10 +45,11 @@ ...@@ -45,10 +45,11 @@
"# Directory to save logs and trained model\n", "# Directory to save logs and trained model\n",
"MODEL_DIR = os.path.join(ROOT_DIR, \"logs\")\n", "MODEL_DIR = os.path.join(ROOT_DIR, \"logs\")\n",
"\n", "\n",
"# Path to trained weights file\n", "# Local path to trained weights file\n",
"# Download this file and place in the root of your \n",
"# project (See README file for details)\n",
"COCO_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_coco.h5\")\n", "COCO_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_coco.h5\")\n",
"# Download COCO trained weights from Releases if needed\n",
"if not os.path.exists(COCO_MODEL_PATH):\n",
" utils.download_trained_weights(COCO_MODEL_PATH)\n",
"\n", "\n",
"# Directory of images to run detection on\n", "# Directory of images to run detection on\n",
"IMAGE_DIR = os.path.join(ROOT_DIR, \"images\")" "IMAGE_DIR = os.path.join(ROOT_DIR, \"images\")"
...@@ -144,6 +145,7 @@ ...@@ -144,6 +145,7 @@
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 3,
"metadata": { "metadata": {
"collapsed": true,
"scrolled": false "scrolled": false
}, },
"outputs": [], "outputs": [],
...@@ -282,7 +284,7 @@ ...@@ -282,7 +284,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.5.2" "version": "3.6.2"
} }
}, },
"nbformat": 4, "nbformat": 4,
...@@ -49,10 +49,11 @@ ...@@ -49,10 +49,11 @@
"# Directory to save logs and trained model\n", "# Directory to save logs and trained model\n",
"MODEL_DIR = os.path.join(ROOT_DIR, \"logs\")\n", "MODEL_DIR = os.path.join(ROOT_DIR, \"logs\")\n",
"\n", "\n",
"# Path to trained weights file\n", "# Local path to trained weights file\n",
"# Download this file and place in the root of your \n",
"# project (See README file for details)\n",
"COCO_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_coco.h5\")\n", "COCO_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_coco.h5\")\n",
"# Download COCO trained weights from Releases if needed\n",
"if not os.path.exists(COCO_MODEL_PATH):\n",
" utils.download_trained_weights(COCO_MODEL_PATH)\n",
"\n", "\n",
"# Path to Shapes trained weights\n", "# Path to Shapes trained weights\n",
"SHAPES_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_shapes.h5\")" "SHAPES_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_shapes.h5\")"
...@@ -1377,7 +1378,7 @@ ...@@ -1377,7 +1378,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.5.2" "version": "3.6.2"
} }
}, },
"nbformat": 4, "nbformat": 4,
...@@ -44,8 +44,11 @@ ...@@ -44,8 +44,11 @@
"# Directory to save logs and trained model\n", "# Directory to save logs and trained model\n",
"MODEL_DIR = os.path.join(ROOT_DIR, \"logs\")\n", "MODEL_DIR = os.path.join(ROOT_DIR, \"logs\")\n",
"\n", "\n",
"# Path to COCO trained weights\n", "# Local path to trained weights file\n",
"COCO_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_coco.h5\")\n", "COCO_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_coco.h5\")\n",
"# Download COCO trained weights from Releases if needed\n",
"if not os.path.exists(COCO_MODEL_PATH):\n",
" utils.download_trained_weights(COCO_MODEL_PATH)\n",
"\n", "\n",
"# Path to Shapes trained weights\n", "# Path to Shapes trained weights\n",
"SHAPES_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_shapes.h5\")" "SHAPES_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_shapes.h5\")"
...@@ -266,7 +269,7 @@ ...@@ -266,7 +269,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.5.2" "version": "3.6.2"
} }
}, },
"nbformat": 4, "nbformat": 4,
...@@ -51,8 +51,11 @@ ...@@ -51,8 +51,11 @@
"# Directory to save logs and trained model\n", "# Directory to save logs and trained model\n",
"MODEL_DIR = os.path.join(ROOT_DIR, \"logs\")\n", "MODEL_DIR = os.path.join(ROOT_DIR, \"logs\")\n",
"\n", "\n",
"# Path to COCO trained weights\n", "# Local path to trained weights file\n",
"COCO_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_coco.h5\")" "COCO_MODEL_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_coco.h5\")\n",
"# Download COCO trained weights from Releases if needed\n",
"if not os.path.exists(COCO_MODEL_PATH):\n",
" utils.download_trained_weights(COCO_MODEL_PATH)"
] ]
}, },
{ {
...@@ -1024,7 +1027,7 @@ ...@@ -1024,7 +1027,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.5.2" "version": "3.6.2"
} }
}, },
"nbformat": 4, "nbformat": 4,
......
...@@ -16,6 +16,11 @@ import tensorflow as tf ...@@ -16,6 +16,11 @@ import tensorflow as tf
import scipy.misc import scipy.misc
import skimage.color import skimage.color
import skimage.io import skimage.io
import urllib.request
import shutil
# URL from which to download the latest COCO trained weights
COCO_MODEL_URL = "https://github.com/matterport/Mask_RCNN/releases/download/v2.0/mask_rcnn_coco.h5"
############################################################ ############################################################
...@@ -688,3 +693,16 @@ def batch_slice(inputs, graph_fn, batch_size, names=None): ...@@ -688,3 +693,16 @@ def batch_slice(inputs, graph_fn, batch_size, names=None):
result = result[0] result = result[0]
return result return result
def download_trained_weights(coco_model_path, verbose=1):
"""Download COCO trained weights from Releases.
coco_model_path: local path of COCO trained weights
"""
if verbose > 0:
print("Downloading pretrained model to " + coco_model_path + " ...")
with urllib.request.urlopen(COCO_MODEL_URL) as resp, open(coco_model_path, 'wb') as out:
shutil.copyfileobj(resp, out)
if verbose > 0:
print("... done downloading pretrained model!")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册