未验证 提交 6e54823f 编写于 作者: W wuyefeilin 提交者: GitHub

Merge pull request #243 from wuyefeilin/humanseg

add image_shape argparse
......@@ -65,7 +65,8 @@ python train.py --model_type HumanSegMobile \
--pretrained_weights pretrained_weights/humanseg_mobile \
--batch_size 8 \
--learning_rate 0.001 \
--num_epochs 10
--num_epochs 10 \
--image_shape 192 192
```
其中参数含义如下:
* `--model_type`: 模型类型,可选项为:HumanSegServer、HumanSegMobile和HumanSegLite
......@@ -77,6 +78,7 @@ python train.py --model_type HumanSegMobile \
* `--batch_size`: 批大小
* `--learning_rate`: 初始学习率
* `--num_epochs`: 训练轮数
* `--image_shape`: 网络输入图像大小(w, h)
更多命令行帮助可运行下述命令进行查看:
```bash
......@@ -90,24 +92,28 @@ python train.py --help
```bash
python val.py --model_dir output/best_model \
--data_dir data/mini_supervisely \
--val_list data/mini_supervisely/val.txt
--val_list data/mini_supervisely/val.txt \
--image_shape 192 192
```
其中参数含义如下:
* `--model_dir`: 模型路径
* `--data_dir`: 数据集路径
* `--val_list`: 验证集列表路径
* `--image_shape`: 网络输入图像大小(w, h)
## 预测
使用下述命令进行预测
```bash
python infer.py --model_dir output/best_model \
--data_dir data/mini_supervisely \
--test_list data/mini_supervisely/test.txt
--test_list data/mini_supervisely/test.txt \
--image_shape 192 192
```
其中参数含义如下:
* `--model_dir`: 模型路径
* `--data_dir`: 数据集路径
* `--test_list`: 测试集列表路径
* `--image_shape`: 网络输入图像大小(w, h)
## 模型导出
```bash
......@@ -124,13 +130,15 @@ python export.py --model_dir output/best_model \
python quant_offline.py --model_dir output/best_model \
--data_dir data/mini_supervisely \
--quant_list data/mini_supervisely/val.txt \
--save_dir output/quant_offline
--save_dir output/quant_offline \
--image_shape 192 192
```
其中参数含义如下:
* `--model_dir`: 待量化模型路径
* `--data_dir`: 数据集路径
* `--quant_list`: 量化数据集列表路径,一般直接选择训练集或验证集
* `--save_dir`: 量化模型保存路径
* `--image_shape`: 网络输入图像大小(w, h)
## 在线量化
利用float训练模型进行在线量化。
......@@ -143,7 +151,8 @@ python quant_online.py --model_type HumanSegMobile \
--pretrained_weights output/best_model \
--batch_size 2 \
--learning_rate 0.001 \
--num_epochs 2
--num_epochs 2 \
--image_shape 192 192
```
其中参数含义如下:
* `--model_type`: 模型类型,可选项为:HumanSegServer、HumanSegMobile和HumanSegLite
......@@ -155,3 +164,4 @@ python quant_online.py --model_type HumanSegMobile \
* `--batch_size`: 批大小
* `--learning_rate`: 初始学习率
* `--num_epochs`: 训练轮数
* `--image_shape`: 网络输入图像大小(w, h)
......@@ -34,6 +34,13 @@ def parse_args():
help='The directory for saving the inference results',
type=str,
default='./output/result')
parser.add_argument(
"--image_shape",
dest="image_shape",
help="The image shape for net inputs.",
nargs=2,
default=[192, 192],
type=int)
return parser.parse_args()
......@@ -45,7 +52,7 @@ def mkdir(path):
def infer(args):
test_transforms = transforms.Compose(
[transforms.Resize((192, 192)),
[transforms.Resize(args.image_shape),
transforms.Normalize()])
model = models.load_model(args.model_dir)
added_saveed_path = osp.join(args.save_dir, 'added')
......
......@@ -42,12 +42,19 @@ def parse_args():
help='The directory for saving the quant model',
type=str,
default='./output/quant_offline')
parser.add_argument(
"--image_shape",
dest="image_shape",
help="The image shape for net inputs.",
nargs=2,
default=[192, 192],
type=int)
return parser.parse_args()
def evaluate(args):
eval_transforms = transforms.Compose(
[transforms.Resize((192, 192)),
[transforms.Resize(args.image_shape),
transforms.Normalize()])
eval_dataset = Dataset(
......
......@@ -73,6 +73,13 @@ def parse_args():
help='The interval epochs for save a model snapshot',
type=int,
default=1)
parser.add_argument(
"--image_shape",
dest="image_shape",
help="The image shape for net inputs.",
nargs=2,
default=[192, 192],
type=int)
return parser.parse_args()
......@@ -80,12 +87,12 @@ def parse_args():
def train(args):
train_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.Resize((192, 192)),
transforms.Resize(args.image_shape),
transforms.Normalize()
])
eval_transforms = transforms.Compose(
[transforms.Resize((192, 192)),
[transforms.Resize(args.image_shape),
transforms.Normalize()])
train_dataset = Dataset(
......
......@@ -43,6 +43,13 @@ def parse_args():
help='Number of classes',
type=int,
default=2)
parser.add_argument(
"--image_shape",
dest="image_shape",
help="The image shape for net inputs.",
nargs=2,
default=[192, 192],
type=int)
parser.add_argument(
'--num_epochs',
dest='num_epochs',
......@@ -91,13 +98,13 @@ def parse_args():
def train(args):
train_transforms = transforms.Compose([
transforms.Resize(args.image_shape),
transforms.RandomHorizontalFlip(),
transforms.Resize((192, 192)),
transforms.Normalize()
])
eval_transforms = transforms.Compose(
[transforms.Resize((192, 192)),
[transforms.Resize(args.image_shape),
transforms.Normalize()])
train_dataset = Dataset(
......
......@@ -29,12 +29,19 @@ def parse_args():
help='Mini batch size',
type=int,
default=128)
parser.add_argument(
"--image_shape",
dest="image_shape",
help="The image shape for net inputs.",
nargs=2,
default=[192, 192],
type=int)
return parser.parse_args()
def evaluate(args):
eval_transforms = transforms.Compose(
[transforms.Resize((192, 192)),
[transforms.Resize(args.image_shape),
transforms.Normalize()])
eval_dataset = Dataset(
......
......@@ -29,6 +29,13 @@ def parse_args():
help='The directory for saving the inference results',
type=str,
default='./output')
parser.add_argument(
"--image_shape",
dest="image_shape",
help="The image shape for net inputs.",
nargs=2,
default=[192, 192],
type=int)
return parser.parse_args()
......@@ -60,9 +67,8 @@ def recover(img, im_info):
def video_infer(args):
resize_h = 192
resize_w = 192
resize_h = args.image_shape[1]
resize_w = args.image_shape[0]
test_transforms = transforms.Compose(
[transforms.Resize((resize_w, resize_h)),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册