diff --git a/contrib/HumanSeg/README.md b/contrib/HumanSeg/README.md index 1f8ab38055072b6eeec8f6d40ba36c5bbe26bfea..c5ba095749e592bbb3d866935bfde4a904323862 100644 --- a/contrib/HumanSeg/README.md +++ b/contrib/HumanSeg/README.md @@ -65,7 +65,8 @@ python train.py --model_type HumanSegMobile \ --pretrained_weights pretrained_weights/humanseg_mobile \ --batch_size 8 \ --learning_rate 0.001 \ ---num_epochs 10 +--num_epochs 10 \ +--image_shape 192 192 ``` 其中参数含义如下: * `--model_type`: 模型类型,可选项为:HumanSegServer、HumanSegMobile和HumanSegLite @@ -77,6 +78,7 @@ python train.py --model_type HumanSegMobile \ * `--batch_size`: 批大小 * `--learning_rate`: 初始学习率 * `--num_epochs`: 训练轮数 +* `--image_shape`: 网络输入图像大小(w, h) 更多命令行帮助可运行下述命令进行查看: ```bash @@ -90,24 +92,28 @@ python train.py --help ```bash python val.py --model_dir output/best_model \ --data_dir data/mini_supervisely \ ---val_list data/mini_supervisely/val.txt +--val_list data/mini_supervisely/val.txt \ +--image_shape 192 192 ``` 其中参数含义如下: * `--model_dir`: 模型路径 * `--data_dir`: 数据集路径 * `--val_list`: 验证集列表路径 +* `--image_shape`: 网络输入图像大小(w, h) ## 预测 使用下述命令进行预测 ```bash python infer.py --model_dir output/best_model \ --data_dir data/mini_supervisely \ ---test_list data/mini_supervisely/test.txt +--test_list data/mini_supervisely/test.txt \ +--image_shape 192 192 ``` 其中参数含义如下: * `--model_dir`: 模型路径 * `--data_dir`: 数据集路径 * `--test_list`: 测试集列表路径 +* `--image_shape`: 网络输入图像大小(w, h) ## 模型导出 ```bash @@ -124,13 +130,15 @@ python export.py --model_dir output/best_model \ python quant_offline.py --model_dir output/best_model \ --data_dir data/mini_supervisely \ --quant_list data/mini_supervisely/val.txt \ ---save_dir output/quant_offline +--save_dir output/quant_offline \ +--image_shape 192 192 ``` 其中参数含义如下: * `--model_dir`: 待量化模型路径 * `--data_dir`: 数据集路径 * `--quant_list`: 量化数据集列表路径,一般直接选择训练集或验证集 * `--save_dir`: 量化模型保存路径 +* `--image_shape`: 网络输入图像大小(w, h) ## 在线量化 利用float训练模型进行在线量化。 @@ -143,7 +151,8 @@ python quant_online.py --model_type HumanSegMobile \ --pretrained_weights output/best_model \ --batch_size 2 \ --learning_rate 0.001 \ ---num_epochs 2 +--num_epochs 2 \ +--image_shape 192 192 ``` 其中参数含义如下: * `--model_type`: 模型类型,可选项为:HumanSegServer、HumanSegMobile和HumanSegLite @@ -155,3 +164,4 @@ python quant_online.py --model_type HumanSegMobile \ * `--batch_size`: 批大小 * `--learning_rate`: 初始学习率 * `--num_epochs`: 训练轮数 +* `--image_shape`: 网络输入图像大小(w, h) diff --git a/contrib/HumanSeg/infer.py b/contrib/HumanSeg/infer.py index 3e5ac8363239d9ae0ee8afdb32f6750b74d605b6..96aabac6c44c164504f6626accfadd36983219e5 100644 --- a/contrib/HumanSeg/infer.py +++ b/contrib/HumanSeg/infer.py @@ -34,6 +34,13 @@ def parse_args(): help='The directory for saving the inference results', type=str, default='./output/result') + parser.add_argument( + "--image_shape", + dest="image_shape", + help="The image shape for net inputs.", + nargs=2, + default=[192, 192], + type=int) return parser.parse_args() @@ -45,7 +52,7 @@ def mkdir(path): def infer(args): test_transforms = transforms.Compose( - [transforms.Resize((192, 192)), + [transforms.Resize(args.image_shape), transforms.Normalize()]) model = models.load_model(args.model_dir) added_saveed_path = osp.join(args.save_dir, 'added') diff --git a/contrib/HumanSeg/quant_offline.py b/contrib/HumanSeg/quant_offline.py index c2d51e0c3995bce65f89a7cc44a0a12bc7ace3b2..92a393f07bd2b70fc7df658290abf440f3069752 100644 --- a/contrib/HumanSeg/quant_offline.py +++ b/contrib/HumanSeg/quant_offline.py @@ -42,12 +42,19 @@ def parse_args(): help='The directory for saving the quant model', type=str, default='./output/quant_offline') + parser.add_argument( + "--image_shape", + dest="image_shape", + help="The image shape for net inputs.", + nargs=2, + default=[192, 192], + type=int) return parser.parse_args() def evaluate(args): eval_transforms = transforms.Compose( - [transforms.Resize((192, 192)), + [transforms.Resize(args.image_shape), transforms.Normalize()]) eval_dataset = Dataset( diff --git a/contrib/HumanSeg/quant_online.py b/contrib/HumanSeg/quant_online.py index fbc1d026ae0249d6f326c4e01aec3a803d0d8925..04eea4d3d9f357897e300da87297a8f6c9515e06 100644 --- a/contrib/HumanSeg/quant_online.py +++ b/contrib/HumanSeg/quant_online.py @@ -73,6 +73,13 @@ def parse_args(): help='The interval epochs for save a model snapshot', type=int, default=1) + parser.add_argument( + "--image_shape", + dest="image_shape", + help="The image shape for net inputs.", + nargs=2, + default=[192, 192], + type=int) return parser.parse_args() @@ -80,12 +87,12 @@ def parse_args(): def train(args): train_transforms = transforms.Compose([ transforms.RandomHorizontalFlip(), - transforms.Resize((192, 192)), + transforms.Resize(args.image_shape), transforms.Normalize() ]) eval_transforms = transforms.Compose( - [transforms.Resize((192, 192)), + [transforms.Resize(args.image_shape), transforms.Normalize()]) train_dataset = Dataset( diff --git a/contrib/HumanSeg/train.py b/contrib/HumanSeg/train.py index c74639627a0a49857b59947edb77f4c071bf4894..65e66ae1dcd07f65b062744216ed6cbfc85cad40 100644 --- a/contrib/HumanSeg/train.py +++ b/contrib/HumanSeg/train.py @@ -43,6 +43,13 @@ def parse_args(): help='Number of classes', type=int, default=2) + parser.add_argument( + "--image_shape", + dest="image_shape", + help="The image shape for net inputs.", + nargs=2, + default=[192, 192], + type=int) parser.add_argument( '--num_epochs', dest='num_epochs', @@ -91,13 +98,13 @@ def parse_args(): def train(args): train_transforms = transforms.Compose([ + transforms.Resize(args.image_shape), transforms.RandomHorizontalFlip(), - transforms.Resize((192, 192)), transforms.Normalize() ]) eval_transforms = transforms.Compose( - [transforms.Resize((192, 192)), + [transforms.Resize(args.image_shape), transforms.Normalize()]) train_dataset = Dataset( diff --git a/contrib/HumanSeg/val.py b/contrib/HumanSeg/val.py index ffceecdc40746dd471817ab457011f1948f14bc3..cecdb5d5c579b22688a092d700863737ec35a13d 100644 --- a/contrib/HumanSeg/val.py +++ b/contrib/HumanSeg/val.py @@ -29,12 +29,19 @@ def parse_args(): help='Mini batch size', type=int, default=128) + parser.add_argument( + "--image_shape", + dest="image_shape", + help="The image shape for net inputs.", + nargs=2, + default=[192, 192], + type=int) return parser.parse_args() def evaluate(args): eval_transforms = transforms.Compose( - [transforms.Resize((192, 192)), + [transforms.Resize(args.image_shape), transforms.Normalize()]) eval_dataset = Dataset( diff --git a/contrib/HumanSeg/video_infer.py b/contrib/HumanSeg/video_infer.py index bc8ee690e343113a3091293862ee836ac901a379..b248669cf9455e908d2c8dfb98f8edae273f73a9 100644 --- a/contrib/HumanSeg/video_infer.py +++ b/contrib/HumanSeg/video_infer.py @@ -29,6 +29,13 @@ def parse_args(): help='The directory for saving the inference results', type=str, default='./output') + parser.add_argument( + "--image_shape", + dest="image_shape", + help="The image shape for net inputs.", + nargs=2, + default=[192, 192], + type=int) return parser.parse_args() @@ -60,9 +67,8 @@ def recover(img, im_info): def video_infer(args): - - resize_h = 192 - resize_w = 192 + resize_h = args.image_shape[1] + resize_w = args.image_shape[0] test_transforms = transforms.Compose( [transforms.Resize((resize_w, resize_h)),