提交 ff51164e 编写于 作者: L LielinJiang

fix bug from qa

上级 42e7a737
...@@ -71,10 +71,10 @@ distill_loss = l2_loss('teacher_bilinear_interp_2.tmp_0', 'bilinear_interp_1.tmp ...@@ -71,10 +71,10 @@ distill_loss = l2_loss('teacher_bilinear_interp_2.tmp_0', 'bilinear_interp_1.tmp
### 执行示例 ### 执行示例
如下命令启动训练,每间隔```cfg.TRAIN.SNAPSHOT_EPOCH```会进行一次评估。 如下命令启动训练,每间隔```cfg.TRAIN.SNAPSHOT_EPOCH```会进行一次评估。
```shell ```shell
CUDA_VISIBLE_DEVICES=0,1 CUDA_VISIBLE_DEVICES=0,1
python -m paddle.distributed.launch ./slim/distill/train.py \ python -m paddle.distributed.launch ./slim/distillation/train_distill.py \
--log_steps 10 --cfg ./slim/distill/cityscape_fast_scnn.yaml \ --log_steps 10 --cfg ./slim/distillation/cityscape.yaml \
--teacher_cfg ./slim/distill/cityscape_teacher.yaml \ --teacher_cfg ./slim/distillation/cityscape_teacher.yaml \
--use_gpu \ --use_gpu \
--use_mpio \ --use_mpio \
--do_eval --do_eval
......
...@@ -157,7 +157,7 @@ def export_preprocess(image): ...@@ -157,7 +157,7 @@ def export_preprocess(image):
def build_model(main_prog=None, start_prog=None, phase=ModelPhase.TRAIN, **kwargs): def build_model(main_prog=None, start_prog=None, phase=ModelPhase.TRAIN, **kwargs):
print('debugggggggggg')
if not ModelPhase.is_valid_phase(phase): if not ModelPhase.is_valid_phase(phase):
raise ValueError("ModelPhase {} is not valid!".format(phase)) raise ValueError("ModelPhase {} is not valid!".format(phase))
if ModelPhase.is_train(phase): if ModelPhase.is_train(phase):
...@@ -176,7 +176,6 @@ def build_model(main_prog=None, start_prog=None, phase=ModelPhase.TRAIN, **kwarg ...@@ -176,7 +176,6 @@ def build_model(main_prog=None, start_prog=None, phase=ModelPhase.TRAIN, **kwarg
# 在导出模型的时候,增加图像标准化预处理,减小预测部署时图像的处理流程 # 在导出模型的时候,增加图像标准化预处理,减小预测部署时图像的处理流程
# 预测部署时只须对输入图像增加batch_size维度即可 # 预测部署时只须对输入图像增加batch_size维度即可
if cfg.SLIM.KNOWLEDGE_DISTILL_IS_TEACHER: if cfg.SLIM.KNOWLEDGE_DISTILL_IS_TEACHER:
print('teacher input:')
image = main_prog.global_block()._clone_variable(kwargs['image'], image = main_prog.global_block()._clone_variable(kwargs['image'],
force_persistable=False) force_persistable=False)
label = main_prog.global_block()._clone_variable(kwargs['label'], label = main_prog.global_block()._clone_variable(kwargs['label'],
......
...@@ -23,7 +23,7 @@ import sys ...@@ -23,7 +23,7 @@ import sys
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__)) LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
SEG_PATH = os.path.join(LOCAL_PATH, "../../", "pdseg") SEG_PATH = os.path.join(LOCAL_PATH, "../../", "pdseg")
sys.path.append(SEG_PATH) sys.path.append(SEG_PATH)
sys.path.append('/workspace/codes/PaddleSlim1')
import argparse import argparse
import pprint import pprint
import random import random
...@@ -278,8 +278,6 @@ def train(cfg): ...@@ -278,8 +278,6 @@ def train(cfg):
label=grts, mask=masks) label=grts, mask=masks)
exe.run(teacher_startup_program) exe.run(teacher_startup_program)
# assert FLAGS.teacher_pretrained, "teacher_pretrained should be set"
# checkpoint.load_params(exe, teacher_program, FLAGS.teacher_pretrained)
teacher_program = teacher_program.clone(for_test=True) teacher_program = teacher_program.clone(for_test=True)
ckpt_dir = cfg.SLIM.KNOWLEDGE_DISTILL_TEACHER_MODEL_DIR ckpt_dir = cfg.SLIM.KNOWLEDGE_DISTILL_TEACHER_MODEL_DIR
...@@ -295,14 +293,14 @@ def train(cfg): ...@@ -295,14 +293,14 @@ def train(cfg):
'mask': 'mask', 'mask': 'mask',
} }
merge(teacher_program, fluid.default_main_program(), data_name_map, place) merge(teacher_program, fluid.default_main_program(), data_name_map, place)
distill_pairs = [['teacher_bilinear_interp_2.tmp_0', 'bilinear_interp_1.tmp_0']] distill_pairs = [['teacher_bilinear_interp_2.tmp_0', 'bilinear_interp_0.tmp_0']]
def distill(pairs, weight): def distill(pairs, weight):
""" """
Add 3 pairs of distillation losses, each pair of feature maps is the Add 3 pairs of distillation losses, each pair of feature maps is the
input of teacher and student's yolov3_loss respectively input of teacher and student's yolov3_loss respectively
""" """
loss = l2_loss(pairs[0][0], pairs[0][1], masks) loss = l2_loss(pairs[0][0], pairs[0][1])
weighted_loss = loss * weight weighted_loss = loss * weight
return weighted_loss return weighted_loss
......
...@@ -46,7 +46,7 @@ SLIM: ...@@ -46,7 +46,7 @@ SLIM:
## 训练与评估 ## 训练与评估
执行以下命令,边训练边评估 执行以下命令,边训练边评估
```shell ```shell
python -u ./slim/nas/train.py --log_steps 10 --cfg configs/cityscape.yaml --use_gpu --use_mpio \ python -u ./slim/nas/train_nas.py --log_steps 10 --cfg configs/deeplabv3p_mobilenetv2_cityscapes.yaml --use_gpu --use_mpio \
SLIM.NAS_PORT 23333 \ SLIM.NAS_PORT 23333 \
SLIM.NAS_ADDRESS "" \ SLIM.NAS_ADDRESS "" \
SLIM.NAS_SEARCH_STEPS 2 \ SLIM.NAS_SEARCH_STEPS 2 \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册