提交 4968e6c2 编写于 作者: C chenguowei01

add image_shape argparse

上级 b96c8474
...@@ -65,7 +65,8 @@ python train.py --model_type HumanSegMobile \ ...@@ -65,7 +65,8 @@ python train.py --model_type HumanSegMobile \
--pretrained_weights pretrained_weights/humanseg_mobile \ --pretrained_weights pretrained_weights/humanseg_mobile \
--batch_size 8 \ --batch_size 8 \
--learning_rate 0.001 \ --learning_rate 0.001 \
--num_epochs 10 --num_epochs 10 \
--image_shape 192 192
``` ```
其中参数含义如下: 其中参数含义如下:
* `--model_type`: 模型类型,可选项为:HumanSegServer、HumanSegMobile和HumanSegLite * `--model_type`: 模型类型,可选项为:HumanSegServer、HumanSegMobile和HumanSegLite
...@@ -77,6 +78,7 @@ python train.py --model_type HumanSegMobile \ ...@@ -77,6 +78,7 @@ python train.py --model_type HumanSegMobile \
* `--batch_size`: 批大小 * `--batch_size`: 批大小
* `--learning_rate`: 初始学习率 * `--learning_rate`: 初始学习率
* `--num_epochs`: 训练轮数 * `--num_epochs`: 训练轮数
* `--image_shape`: 网络输入图像大小(w, h)
更多命令行帮助可运行下述命令进行查看: 更多命令行帮助可运行下述命令进行查看:
```bash ```bash
...@@ -90,24 +92,28 @@ python train.py --help ...@@ -90,24 +92,28 @@ python train.py --help
```bash ```bash
python val.py --model_dir output/best_model \ python val.py --model_dir output/best_model \
--data_dir data/mini_supervisely \ --data_dir data/mini_supervisely \
--val_list data/mini_supervisely/val.txt --val_list data/mini_supervisely/val.txt \
--image_shape 192 192
``` ```
其中参数含义如下: 其中参数含义如下:
* `--model_dir`: 模型路径 * `--model_dir`: 模型路径
* `--data_dir`: 数据集路径 * `--data_dir`: 数据集路径
* `--val_list`: 验证集列表路径 * `--val_list`: 验证集列表路径
* `--image_shape`: 网络输入图像大小(w, h)
## 预测 ## 预测
使用下述命令进行预测 使用下述命令进行预测
```bash ```bash
python infer.py --model_dir output/best_model \ python infer.py --model_dir output/best_model \
--data_dir data/mini_supervisely \ --data_dir data/mini_supervisely \
--test_list data/mini_supervisely/test.txt --test_list data/mini_supervisely/test.txt \
--image_shape 192 192
``` ```
其中参数含义如下: 其中参数含义如下:
* `--model_dir`: 模型路径 * `--model_dir`: 模型路径
* `--data_dir`: 数据集路径 * `--data_dir`: 数据集路径
* `--test_list`: 测试集列表路径 * `--test_list`: 测试集列表路径
* `--image_shape`: 网络输入图像大小(w, h)
## 模型导出 ## 模型导出
```bash ```bash
...@@ -124,13 +130,15 @@ python export.py --model_dir output/best_model \ ...@@ -124,13 +130,15 @@ python export.py --model_dir output/best_model \
python quant_offline.py --model_dir output/best_model \ python quant_offline.py --model_dir output/best_model \
--data_dir data/mini_supervisely \ --data_dir data/mini_supervisely \
--quant_list data/mini_supervisely/val.txt \ --quant_list data/mini_supervisely/val.txt \
--save_dir output/quant_offline --save_dir output/quant_offline \
--image_shape 192 192
``` ```
其中参数含义如下: 其中参数含义如下:
* `--model_dir`: 待量化模型路径 * `--model_dir`: 待量化模型路径
* `--data_dir`: 数据集路径 * `--data_dir`: 数据集路径
* `--quant_list`: 量化数据集列表路径,一般直接选择训练集或验证集 * `--quant_list`: 量化数据集列表路径,一般直接选择训练集或验证集
* `--save_dir`: 量化模型保存路径 * `--save_dir`: 量化模型保存路径
* `--image_shape`: 网络输入图像大小(w, h)
## 在线量化 ## 在线量化
利用float训练模型进行在线量化。 利用float训练模型进行在线量化。
...@@ -143,7 +151,8 @@ python quant_online.py --model_type HumanSegMobile \ ...@@ -143,7 +151,8 @@ python quant_online.py --model_type HumanSegMobile \
--pretrained_weights output/best_model \ --pretrained_weights output/best_model \
--batch_size 2 \ --batch_size 2 \
--learning_rate 0.001 \ --learning_rate 0.001 \
--num_epochs 2 --num_epochs 2 \
--image_shape 192 192
``` ```
其中参数含义如下: 其中参数含义如下:
* `--model_type`: 模型类型,可选项为:HumanSegServer、HumanSegMobile和HumanSegLite * `--model_type`: 模型类型,可选项为:HumanSegServer、HumanSegMobile和HumanSegLite
...@@ -155,3 +164,4 @@ python quant_online.py --model_type HumanSegMobile \ ...@@ -155,3 +164,4 @@ python quant_online.py --model_type HumanSegMobile \
* `--batch_size`: 批大小 * `--batch_size`: 批大小
* `--learning_rate`: 初始学习率 * `--learning_rate`: 初始学习率
* `--num_epochs`: 训练轮数 * `--num_epochs`: 训练轮数
* `--image_shape`: 网络输入图像大小(w, h)
...@@ -34,6 +34,13 @@ def parse_args(): ...@@ -34,6 +34,13 @@ def parse_args():
help='The directory for saving the inference results', help='The directory for saving the inference results',
type=str, type=str,
default='./output/result') default='./output/result')
parser.add_argument(
"--image_shape",
dest="image_shape",
help="The image shape for net inputs.",
nargs=2,
default=[192, 192],
type=int)
return parser.parse_args() return parser.parse_args()
...@@ -45,7 +52,7 @@ def mkdir(path): ...@@ -45,7 +52,7 @@ def mkdir(path):
def infer(args): def infer(args):
test_transforms = transforms.Compose( test_transforms = transforms.Compose(
[transforms.Resize((192, 192)), [transforms.Resize(args.image_shape),
transforms.Normalize()]) transforms.Normalize()])
model = models.load_model(args.model_dir) model = models.load_model(args.model_dir)
added_saveed_path = osp.join(args.save_dir, 'added') added_saveed_path = osp.join(args.save_dir, 'added')
......
...@@ -42,12 +42,19 @@ def parse_args(): ...@@ -42,12 +42,19 @@ def parse_args():
help='The directory for saving the quant model', help='The directory for saving the quant model',
type=str, type=str,
default='./output/quant_offline') default='./output/quant_offline')
parser.add_argument(
"--image_shape",
dest="image_shape",
help="The image shape for net inputs.",
nargs=2,
default=[192, 192],
type=int)
return parser.parse_args() return parser.parse_args()
def evaluate(args): def evaluate(args):
eval_transforms = transforms.Compose( eval_transforms = transforms.Compose(
[transforms.Resize((192, 192)), [transforms.Resize(args.image_shape),
transforms.Normalize()]) transforms.Normalize()])
eval_dataset = Dataset( eval_dataset = Dataset(
......
...@@ -73,6 +73,13 @@ def parse_args(): ...@@ -73,6 +73,13 @@ def parse_args():
help='The interval epochs for save a model snapshot', help='The interval epochs for save a model snapshot',
type=int, type=int,
default=1) default=1)
parser.add_argument(
"--image_shape",
dest="image_shape",
help="The image shape for net inputs.",
nargs=2,
default=[192, 192],
type=int)
return parser.parse_args() return parser.parse_args()
...@@ -80,12 +87,12 @@ def parse_args(): ...@@ -80,12 +87,12 @@ def parse_args():
def train(args): def train(args):
train_transforms = transforms.Compose([ train_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(), transforms.RandomHorizontalFlip(),
transforms.Resize((192, 192)), transforms.Resize(args.image_shape),
transforms.Normalize() transforms.Normalize()
]) ])
eval_transforms = transforms.Compose( eval_transforms = transforms.Compose(
[transforms.Resize((192, 192)), [transforms.Resize(args.image_shape),
transforms.Normalize()]) transforms.Normalize()])
train_dataset = Dataset( train_dataset = Dataset(
......
...@@ -43,6 +43,13 @@ def parse_args(): ...@@ -43,6 +43,13 @@ def parse_args():
help='Number of classes', help='Number of classes',
type=int, type=int,
default=2) default=2)
parser.add_argument(
"--image_shape",
dest="image_shape",
help="The image shape for net inputs.",
nargs=2,
default=[192, 192],
type=int)
parser.add_argument( parser.add_argument(
'--num_epochs', '--num_epochs',
dest='num_epochs', dest='num_epochs',
...@@ -91,13 +98,13 @@ def parse_args(): ...@@ -91,13 +98,13 @@ def parse_args():
def train(args): def train(args):
train_transforms = transforms.Compose([ train_transforms = transforms.Compose([
transforms.Resize(args.image_shape),
transforms.RandomHorizontalFlip(), transforms.RandomHorizontalFlip(),
transforms.Resize((192, 192)),
transforms.Normalize() transforms.Normalize()
]) ])
eval_transforms = transforms.Compose( eval_transforms = transforms.Compose(
[transforms.Resize((192, 192)), [transforms.Resize(args.image_shape),
transforms.Normalize()]) transforms.Normalize()])
train_dataset = Dataset( train_dataset = Dataset(
......
...@@ -29,12 +29,19 @@ def parse_args(): ...@@ -29,12 +29,19 @@ def parse_args():
help='Mini batch size', help='Mini batch size',
type=int, type=int,
default=128) default=128)
parser.add_argument(
"--image_shape",
dest="image_shape",
help="The image shape for net inputs.",
nargs=2,
default=[192, 192],
type=int)
return parser.parse_args() return parser.parse_args()
def evaluate(args): def evaluate(args):
eval_transforms = transforms.Compose( eval_transforms = transforms.Compose(
[transforms.Resize((192, 192)), [transforms.Resize(args.image_shape),
transforms.Normalize()]) transforms.Normalize()])
eval_dataset = Dataset( eval_dataset = Dataset(
......
...@@ -29,6 +29,13 @@ def parse_args(): ...@@ -29,6 +29,13 @@ def parse_args():
help='The directory for saving the inference results', help='The directory for saving the inference results',
type=str, type=str,
default='./output') default='./output')
parser.add_argument(
"--image_shape",
dest="image_shape",
help="The image shape for net inputs.",
nargs=2,
default=[192, 192],
type=int)
return parser.parse_args() return parser.parse_args()
...@@ -60,9 +67,8 @@ def recover(img, im_info): ...@@ -60,9 +67,8 @@ def recover(img, im_info):
def video_infer(args): def video_infer(args):
resize_h = args.image_shape[1]
resize_h = 192 resize_w = args.image_shape[0]
resize_w = 192
test_transforms = transforms.Compose( test_transforms = transforms.Compose(
[transforms.Resize((resize_w, resize_h)), [transforms.Resize((resize_w, resize_h)),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册