Unverified commit 353e7529 · Author: LielinJiang · Committer: GitHub

Merge pull request #161 from LielinJiang/fix_bug_0211

fix nas python3 bug and update distill pretrained model
......@@ -69,7 +69,10 @@ distill_loss = l2_loss('teacher_bilinear_interp_2.tmp_0', 'bilinear_interp_0.tmp
This script defines teacher_model and student_model, and uses the output of teacher_model to guide the training of student_model.
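As a rough illustration of what "guiding" the student means, the sketch below computes an L2 loss between a teacher output and a student output in plain Python. It assumes the L2 loss is the mean of squared differences; `l2_distill_loss` and the sample values are hypothetical and are not the `l2_loss` call used in the script.

```python
# Sketch only: an L2 (mean squared difference) distillation loss.
# `l2_distill_loss` is a hypothetical helper, not the library's `l2_loss`.

def l2_distill_loss(teacher_out, student_out):
    """Mean of squared differences between two equally sized flat outputs."""
    if not teacher_out or len(teacher_out) != len(student_out):
        raise ValueError("outputs must be non-empty and of the same size")
    squared = [(t - s) ** 2 for t, s in zip(teacher_out, student_out)]
    return sum(squared) / len(squared)


teacher_logits = [2.0, 0.5, -1.0]  # made-up teacher output
student_logits = [1.0, 0.5, 1.0]   # made-up student output
loss = l2_distill_loss(teacher_logits, student_logits)
print(loss)
```

The loss shrinks as the student output approaches the teacher output, which is the signal the student is trained on in addition to its usual segmentation loss.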
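### Execution example
The following command starts training; an evaluation is run every ```cfg.TRAIN.SNAPSHOT_EPOCH``` epochs.
Download the pretrained teacher model and the pretrained student model, then replace ```your_tearcher_pretrained_model_path``` and ```your_student_pretrained_model``` in the command below with their paths.
Run the following command to start training; an evaluation is run every ```cfg.TRAIN.SNAPSHOT_EPOCH``` epochs.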
```shell
CUDA_VISIBLE_DEVICES=0,1
python -m paddle.distributed.launch ./slim/distillation/train_distill.py \
......@@ -77,7 +80,9 @@ python -m paddle.distributed.launch ./slim/distillation/train_distill.py \
--teacher_cfg ./slim/distillation/cityscape_teacher.yaml \
--use_gpu \
--use_mpio \
--do_eval
--do_eval \
SLIM.KNOWLEDGE_DISTILL_TEACHER_MODEL_DIR your_tearcher_pretrained_model_path \
TRAIN.PRETRAINED_MODEL_DIR your_student_pretrained_model
```
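The trailing `KEY VALUE` pairs in the command above override entries of the YAML config. A minimal sketch of how such dotted keys could be merged into a nested config is shown below; `apply_overrides` is a hypothetical helper, and the project's own config loader may parse and type-convert values differently.

```python
# Sketch only: merge trailing "KEY VALUE" pairs into a nested config dict.
# `apply_overrides` is hypothetical; it is not the project's config API.

def apply_overrides(config, pairs):
    """Apply alternating dotted-key/value strings to a nested dict in place."""
    if len(pairs) % 2 != 0:
        raise ValueError("overrides must come in KEY VALUE pairs")
    for key, value in zip(pairs[0::2], pairs[1::2]):
        *parents, leaf = key.split(".")
        node = config
        for name in parents:
            # Walk down the tree, creating missing sections on the way.
            node = node.setdefault(name, {})
        node[leaf] = value
    return config


config = {
    "TRAIN": {"PRETRAINED_MODEL_DIR": "pretrained_model/mobilenet_cityscapes"},
    "SLIM": {"KNOWLEDGE_DISTILL": True},
}
apply_overrides(config, [
    "SLIM.KNOWLEDGE_DISTILL_TEACHER_MODEL_DIR", "your_tearcher_pretrained_model_path",
    "TRAIN.PRETRAINED_MODEL_DIR", "your_student_pretrained_model",
])
print(config)
```

Values given on the command line take precedence over the YAML defaults, so the config files can keep relative default paths while each run points at its own downloaded models.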
## Evaluation and prediction
......
......@@ -49,7 +49,7 @@ TEST:
TEST_MODEL: "snapshots/cityscape_v5/final/"
TRAIN:
MODEL_SAVE_DIR: "snapshots/cityscape_mbv2_kd_e100_1/"
PRETRAINED_MODEL_DIR: u"/workspace/pretrained_models/mobilenet_cityscapes"
PRETRAINED_MODEL_DIR: u"pretrained_model/mobilenet_cityscapes"
SNAPSHOT_EPOCH: 5
SYNC_BATCH_NORM: True
SOLVER:
......
......@@ -61,5 +61,5 @@ SOLVER:
SLIM:
KNOWLEDGE_DISTILL_IS_TEACHER: True
KNOWLEDGE_DISTILL: True
KNOWLEDGE_DISTILL_TEACHER_MODEL_DIR: "/workspace/pretrained_models/xception65_bn_cityscapes"
KNOWLEDGE_DISTILL_TEACHER_MODEL_DIR: "pretrained_model/xception65_bn_cityscapes"
......@@ -46,7 +46,7 @@ SLIM:
## Training and evaluation
Run the following command to train and evaluate at the same time.
```shell
python -u ./slim/nas/train_nas.py --log_steps 10 --cfg configs/deeplabv3p_mobilenetv2_cityscapes.yaml --use_gpu --use_mpio \
CUDA_VISIBLE_DEVICES=0 python -u ./slim/nas/train_nas.py --log_steps 10 --cfg configs/deeplabv3p_mobilenetv2_cityscapes.yaml --use_gpu --use_mpio \
SLIM.NAS_PORT 23333 \
SLIM.NAS_ADDRESS "" \
SLIM.NAS_SEARCH_STEPS 2 \
......
......@@ -180,7 +180,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
c=int(c * self.scale),
n=n,
s=s,
k=k,
k=int(k),
name='mobilenetv2_conv' + str(i))
in_c = int(c * self.scale)
......
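The diff does not show how `k` is computed, so the cause of the Python 3 failure is not stated. One plausible reason, given here purely as an assumption, is that `k` reaches the layer as a float (for example from true division, which returns a float on Python 3) where an integer kernel size is required. The sketch below uses hypothetical arithmetic and a hypothetical `padding_for` helper to show the effect of the `int(k)` cast.

```python
# Sketch only: why a kernel size may need an explicit int() cast on Python 3.
# The arithmetic is hypothetical; the diff does not show where `k` comes from.

raw_k = 6 / 2   # true division yields a float on Python 3
k = int(raw_k)  # the kind of cast applied in the fix


def padding_for(kernel_size):
    """'Same' padding for an odd kernel; insists on an integer kernel size."""
    if not isinstance(kernel_size, int):
        raise TypeError(
            "kernel size must be an int, got %s" % type(kernel_size).__name__
        )
    return (kernel_size - 1) // 2


print(type(raw_k).__name__, type(k).__name__)
print(padding_for(k))
```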