提交 a99bc1a9 编写于 作者: G gengdongjie

bugfix on argpasr for bool

上级 1821e98e
...@@ -50,7 +50,7 @@ MaskRcnn is a two-stage target detection network,This network uses a region prop ...@@ -50,7 +50,7 @@ MaskRcnn is a two-stage target detection network,This network uses a region prop
```shell ```shell
. .
└─MaskRcnn └─maskrcnn
├─README.md ├─README.md
├─scripts ├─scripts
├─run_download_process_data.sh ├─run_download_process_data.sh
...@@ -58,7 +58,7 @@ MaskRcnn is a two-stage target detection network,This network uses a region prop ...@@ -58,7 +58,7 @@ MaskRcnn is a two-stage target detection network,This network uses a region prop
├─run_train.sh ├─run_train.sh
└─run_eval.sh └─run_eval.sh
├─src ├─src
├─MaskRcnn ├─maskrcnn
├─__init__.py ├─__init__.py
├─anchor_generator.py ├─anchor_generator.py
├─bbox_assign_sample.py ├─bbox_assign_sample.py
......
...@@ -24,7 +24,7 @@ from mindspore import context, Tensor ...@@ -24,7 +24,7 @@ from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset.engine as de import mindspore.dataset.engine as de
from src.MaskRcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
from src.config import config from src.config import config
from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
from src.util import coco_eval, bbox2result_1image, results2json, get_seg_masks from src.util import coco_eval, bbox2result_1image, results2json, get_seg_masks
......
...@@ -141,8 +141,8 @@ config = ed({ ...@@ -141,8 +141,8 @@ config = ed({
"keep_checkpoint_max": 12, "keep_checkpoint_max": 12,
"save_checkpoint_path": "./checkpoint", "save_checkpoint_path": "./checkpoint",
"mindrecord_dir": "/home/mxw/mask_rcnn/scripts/MindRecord_COCO2017_Train", "mindrecord_dir": "/home/mask_rcnn/MindRecord_COCO2017_Train",
"coco_root": "/home/mxw/coco2017/", "coco_root": "/home/mask_rcnn/coco2017/",
"train_data_type": "train2017", "train_data_type": "train2017",
"val_data_type": "val2017", "val_data_type": "val2017",
"instance_set": "annotations/instances_{}.json", "instance_set": "annotations/instances_{}.json",
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
import os import os
import argparse import argparse
import random import random
import ast
import numpy as np import numpy as np
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
...@@ -30,7 +31,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net ...@@ -30,7 +31,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import SGD from mindspore.nn import SGD
import mindspore.dataset.engine as de import mindspore.dataset.engine as de
from src.MaskRcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
from src.config import config from src.config import config
from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
...@@ -41,11 +42,11 @@ np.random.seed(1) ...@@ -41,11 +42,11 @@ np.random.seed(1)
de.config.set_seed(1) de.config.set_seed(1)
parser = argparse.ArgumentParser(description="MaskRcnn training") parser = argparse.ArgumentParser(description="MaskRcnn training")
parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create " parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False, help="If set it true, only create "
"Mindrecord, default is false.") "Mindrecord, default is false.")
parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default is false.") parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default is false.")
parser.add_argument("--do_train", type=bool, default=True, help="Do train or not, default is true.") parser.add_argument("--do_train", type=ast.literal_eval, default=True, help="Do train or not, default is true.")
parser.add_argument("--do_eval", type=bool, default=False, help="Do eval or not, default is false.") parser.add_argument("--do_eval", type=ast.literal_eval, default=False, help="Do eval or not, default is false.")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.") parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
parser.add_argument("--pre_trained", type=str, default="", help="Pretrain file path.") parser.add_argument("--pre_trained", type=str, default="", help="Pretrain file path.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册