未验证 提交 42e7a737 编写于 作者: L LielinJiang 提交者: GitHub

release a prune model, refine code (#158)

上级 14473532
......@@ -26,7 +26,6 @@ import sys
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
SEG_PATH = os.path.join(LOCAL_PATH, "../../", "pdseg")
sys.path.append(SEG_PATH)
sys.path.append('/workspace/codes/PaddleSlim1')
import time
import argparse
......
......@@ -49,3 +49,10 @@ CUDA_VISIBLE_DEVICES=0
python -u ./slim/prune/eval_prune.py --cfg configs/cityscape_fast_scnn.yaml --use_gpu --use_mpio \
TEST.TEST_MODEL your_trained_model \
```
## 5. 模型
| 模型 | 数据集合 | 下载地址 |剪裁方法| flops | mIoU on val|
|---|---|---|---|---|---|
| Fast-SCNN/bn | Cityscapes |[fast_scnn_cityscapes.tar](https://paddleseg.bj.bcebos.com/models/fast_scnn_cityscape.tar) | 无 | 7.21g | 0.6964 |
| Fast-SCNN/bn | Cityscapes |[fast_scnn_cityscapes-uniform-51.tar](https://paddleseg.bj.bcebos.com/models/fast_scnn_cityscape-uniform-51.tar) | uniform | 3.54g | 0.6990 |
......@@ -26,7 +26,6 @@ import sys
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
SEG_PATH = os.path.join(LOCAL_PATH, "../../", "pdseg")
sys.path.append(SEG_PATH)
sys.path.append('/workspace/codes/PaddleSlim1')
import time
import argparse
......@@ -124,7 +123,6 @@ def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs):
if ckpt_dir is not None:
print('load test model:', ckpt_dir)
load_model(exe, test_prog, ckpt_dir)
#fluid.io.load_params(exe, ckpt_dir, main_program=test_prog)
# Use streaming confusion matrix to calculate mean_iou
np.set_printoptions(
......
......@@ -146,11 +146,6 @@ def save_prune_checkpoint(exe, program, ckpt_name):
if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir)
# save_vars(
# exe,
# ckpt_dir,
# program,
# vars=list(filter(fluid.io.is_persistable, program.list_vars())))
save_model(exe, program, ckpt_dir)
return ckpt_dir
......@@ -270,26 +265,9 @@ def train(cfg):
print_info("Sync BatchNorm strategy will not be effective if GPU device"
" count <= 1")
####get mobilenetV1 prune parameters####
# pruned_params = []
# exclude_names = ['aux_layer_lower/logit/weights', 'aux_layer_lower/logit/biases',
# 'aux_layer_higher/logit/weights', 'aux_layer_higher/logit/biases',
# 'classifier/weights', 'classifier/biases']
# for x in train_prog.list_vars():
# if isinstance(x, fluid.framework.Parameter):
# # if x.name not in exclude_names:
# if x.name not in exclude_names and "weights" in x.name and "depthwise" not in x.name and "dwise" not in x.name and x.name not in [
# "classifier/dsconv2/pointwise/weights", "classifier/weights", "fc_weights"]:
# pruned_params.append(x.name)
# print(x.name)
# print('to prune paramter number:', len(pruned_params))
# print('listtttt:', ','.join(pruned_params))
# pruned_ratios = [0.1] * len(pruned_params)
########################################
pruned_params = cfg.SLIM.PRUNE_PARAMS.strip().split(',')
#pruned_params = [str(x) for x in pruned_params]
print('paramssss:', pruned_params)
pruned_ratios = cfg.SLIM.PRUNE_RATIOS
if isinstance(pruned_ratios, float):
pruned_ratios = [pruned_ratios] * len(pruned_params)
elif isinstance(pruned_ratios, (list, tuple)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册