Unverified commit ebc9bb93 authored by LielinJiang, committed by GitHub

Merge pull request #160 from LielinJiang/fix_bug_from_qa

Fix some typos
@@ -49,7 +49,7 @@ print(teacher_vars)
 ```bash
 # student model
-bilinear_interp_1.tmp_0
+bilinear_interp_0.tmp_0
 # teacher model
 bilinear_interp_2.tmp_0
 ```
@@ -58,7 +58,7 @@ bilinear_interp_2.tmp_0
 These feature maps match in shape pair by pair and sit at the output end of the two networks, so we use `l2_loss` to add a distillation loss between each corresponding pair. Note that the teacher's Variables are automatically given a `name_prefix` during the merge, so the prefix `"teacher_"` must be added here as well; for the merge step, see the [distillation API docs](https://paddlepaddle.github.io/PaddleSlim/api/single_distiller_api/#merge).
 ```python
-distill_loss = l2_loss('teacher_bilinear_interp_2.tmp_0', 'bilinear_interp_1.tmp_0')
+distill_loss = l2_loss('teacher_bilinear_interp_2.tmp_0', 'bilinear_interp_0.tmp_0')
 ```
 Following the same procedure, we can also choose other losses for the distillation strategy; PaddleSlim supports `FSP_loss`, `L2_loss`, `softmax_with_cross_entropy_loss`, as well as any custom loss.
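For intuition (not part of this diff), the `l2_loss` distillation term amounts to a mean squared error between the teacher and student feature maps. A minimal numpy sketch, with made-up NCHW shapes standing in for the `bilinear_interp` outputs named above:

```python
import numpy as np

def l2_distill_loss(teacher_feat, student_feat):
    # Mean squared difference between two feature maps of identical shape,
    # analogous to what the distillation l2_loss computes on the named Variables.
    assert teacher_feat.shape == student_feat.shape
    return float(np.mean((teacher_feat - student_feat) ** 2))

# Hypothetical feature maps: batch 1, 19 classes, 64x128 spatial size.
rng = np.random.default_rng(0)
teacher = rng.standard_normal((1, 19, 64, 128))
student = rng.standard_normal((1, 19, 64, 128))
loss = l2_distill_loss(teacher, student)
print(loss)  # a non-negative scalar; 0.0 only when the maps match exactly
```

The loss is driven to zero as the student's output approaches the teacher's, which is the signal the distillation strategy trains on.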
@@ -72,9 +72,9 @@ distill_loss = l2_loss('teacher_bilinear_interp_2.tmp_0', 'bilinear_interp_1.tmp
 The following command starts training; an evaluation is run every ```cfg.TRAIN.SNAPSHOT_EPOCH``` epochs.
 ```shell
 CUDA_VISIBLE_DEVICES=0,1
-python -m paddle.distributed.launch ./slim/distill/train.py \
-       --log_steps 10 --cfg ./slim/distill/cityscape_fast_scnn.yaml \
-       --teacher_cfg ./slim/distill/cityscape_teacher.yaml \
+python -m paddle.distributed.launch ./slim/distillation/train_distill.py \
+       --log_steps 10 --cfg ./slim/distillation/cityscape.yaml \
+       --teacher_cfg ./slim/distillation/cityscape_teacher.yaml \
        --use_gpu \
        --use_mpio \
        --do_eval
......
@@ -157,7 +157,7 @@ def export_preprocess(image):
 def build_model(main_prog=None, start_prog=None, phase=ModelPhase.TRAIN, **kwargs):
-    print('debugggggggggg')
     if not ModelPhase.is_valid_phase(phase):
         raise ValueError("ModelPhase {} is not valid!".format(phase))
     if ModelPhase.is_train(phase):
@@ -176,7 +176,6 @@ def build_model(main_prog=None, start_prog=None, phase=ModelPhase.TRAIN, **kwarg
     # When exporting the model, add image-normalization preprocessing to shorten the image-processing pipeline at deployment time
     # At deployment time, only a batch_size dimension needs to be added to the input image
     if cfg.SLIM.KNOWLEDGE_DISTILL_IS_TEACHER:
-        print('teacher input:')
         image = main_prog.global_block()._clone_variable(kwargs['image'],
                                                          force_persistable=False)
         label = main_prog.global_block()._clone_variable(kwargs['label'],
......
@@ -278,8 +278,6 @@ def train(cfg):
             label=grts, mask=masks)
         exe.run(teacher_startup_program)
-        # assert FLAGS.teacher_pretrained, "teacher_pretrained should be set"
-        # checkpoint.load_params(exe, teacher_program, FLAGS.teacher_pretrained)
         teacher_program = teacher_program.clone(for_test=True)
         ckpt_dir = cfg.SLIM.KNOWLEDGE_DISTILL_TEACHER_MODEL_DIR
@@ -295,14 +293,14 @@ def train(cfg):
             'mask': 'mask',
         }
         merge(teacher_program, fluid.default_main_program(), data_name_map, place)
-        distill_pairs = [['teacher_bilinear_interp_2.tmp_0', 'bilinear_interp_1.tmp_0']]
+        distill_pairs = [['teacher_bilinear_interp_2.tmp_0', 'bilinear_interp_0.tmp_0']]
         def distill(pairs, weight):
             """
             Add a distillation loss for the given pair of feature maps,
             the outputs of the teacher and student networks respectively
             """
-            loss = l2_loss(pairs[0][0], pairs[0][1], masks)
+            loss = l2_loss(pairs[0][0], pairs[0][1])
             weighted_loss = loss * weight
             return weighted_loss
......
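As a rough illustration of the `merge` step above (assumption: plain dicts standing in for program variable scopes), the teacher's variables are copied into the student program under a `teacher_` name prefix, which is why the distill pair refers to the teacher feature map by its prefixed name:

```python
def merge_vars(teacher_vars, student_vars, name_prefix="teacher_"):
    # Toy model of PaddleSlim's merge: teacher variables are added to the
    # student's scope under a prefixed name so both can be referenced together.
    merged = dict(student_vars)
    for name, value in teacher_vars.items():
        merged[name_prefix + name] = value
    return merged

teacher_vars = {"bilinear_interp_2.tmp_0": "teacher feature map"}
student_vars = {"bilinear_interp_0.tmp_0": "student feature map"}
merged = merge_vars(teacher_vars, student_vars)
print(sorted(merged))
# ['bilinear_interp_0.tmp_0', 'teacher_bilinear_interp_2.tmp_0']
```

After the (real) merge, both names live in one program, so `l2_loss('teacher_bilinear_interp_2.tmp_0', 'bilinear_interp_0.tmp_0')` can connect the two graphs.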
@@ -46,7 +46,7 @@ SLIM:
 ## Training and evaluation
 Run the following command to train and evaluate at the same time:
 ```shell
-python -u ./slim/nas/train.py --log_steps 10 --cfg configs/cityscape.yaml --use_gpu --use_mpio \
+python -u ./slim/nas/train_nas.py --log_steps 10 --cfg configs/deeplabv3p_mobilenetv2_cityscapes.yaml --use_gpu --use_mpio \
 SLIM.NAS_PORT 23333 \
 SLIM.NAS_ADDRESS "" \
 SLIM.NAS_SEARCH_STEPS 2 \
......