Unverified commit 943e75fd, authored by whs, committed by GitHub

[cherry-pick]cherry pick fix on ocr and distillation for models-1.4 (#2111)

* Fix get_attention_feeder_for_infer (#2067)

* Fix distillation in slim demo. (#2107)
Parent a9fe5d72
@@ -31,7 +31,7 @@ def inference(args):
     """OCR inference"""
     if args.model == "crnn_ctc":
         infer = ctc_infer
-        get_feeder_data = get_ctc_feeder_data
+        get_feeder_data = get_ctc_feeder_for_infer
     else:
         infer = attention_infer
         get_feeder_data = get_attention_feeder_for_infer
@@ -78,7 +78,7 @@ def inference(args):
     batch_times = []
     iters = 0
     for data in infer_reader():
-        feed_dict = get_feeder_data(data, place, need_label=False)
+        feed_dict = get_feeder_data(data, place)
         if args.iterations > 0 and iters == args.iterations + args.skip_batch_num:
             break
         if iters < args.skip_batch_num:
......
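Why this matters: the inference program has no label variable, so the feed dict must contain only the input images; after this change both the CTC and the attention branches expose the same (data, place) feeder signature, and the loop can call them uniformly. A minimal sketch of such a loop, written as a standalone helper for illustration (run_inference and fetch_vars are hypothetical names, not part of infer.py):

def run_inference(exe, infer_program, infer_reader, fetch_vars, place, get_feeder_data):
    """Illustrative only: consume a feeder with the unified (data, place)
    signature and fetch the decoded results for each batch."""
    for data in infer_reader():
        # After the fix, the feeder returns only {"pixel": <LoDTensor>} here.
        feed_dict = get_feeder_data(data, place)
        results = exe.run(infer_program,
                          feed=feed_dict,
                          fetch_list=fetch_vars,
                          return_numpy=False)
        yield results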
@@ -83,7 +83,8 @@ def get_ctc_feeder_data(data, place, need_label=True):
     pixel_tensor = core.LoDTensor()
     pixel_data = None
     pixel_data = np.concatenate(
-        list(map(lambda x: x[0][np.newaxis, :], data)), axis=0).astype("float32")
+        list(map(lambda x: x[0][np.newaxis, :], data)),
+        axis=0).astype("float32")
     pixel_tensor.set(pixel_data, place)
     label_tensor = to_lodtensor(list(map(lambda x: x[1], data)), place)
     if need_label:
@@ -92,11 +93,16 @@ def get_ctc_feeder_data(data, place, need_label=True):
         return {"pixel": pixel_tensor}
 
 
+def get_ctc_feeder_for_infer(data, place):
+    return get_ctc_feeder_data(data, place, need_label=False)
+
+
 def get_attention_feeder_data(data, place, need_label=True):
     pixel_tensor = core.LoDTensor()
     pixel_data = None
     pixel_data = np.concatenate(
-        list(map(lambda x: x[0][np.newaxis, :], data)), axis=0).astype("float32")
+        list(map(lambda x: x[0][np.newaxis, :], data)),
+        axis=0).astype("float32")
     pixel_tensor.set(pixel_data, place)
     label_in_tensor = to_lodtensor(list(map(lambda x: x[1], data)), place)
     label_out_tensor = to_lodtensor(list(map(lambda x: x[2], data)), place)
@@ -127,7 +133,8 @@ def get_attention_feeder_for_infer(data, place):
     pixel_tensor = core.LoDTensor()
     pixel_data = None
     pixel_data = np.concatenate(
-        list(map(lambda x: x[0][np.newaxis, :], data)), axis=0).astype("float32")
+        list(map(lambda x: x[0][np.newaxis, :], data)),
+        axis=0).astype("float32")
     pixel_tensor.set(pixel_data, place)
     return {
         "pixel": pixel_tensor,
......
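Both feeder hunks above pass the label columns through to_lodtensor, a helper that is outside this diff. For context, a minimal sketch of what such a helper typically does, assuming each label is a 1-D integer sequence (the int32 dtype and the [N, 1] reshape are assumptions, not taken from this patch):

import numpy as np
import paddle.fluid.core as core


def to_lodtensor(data, place):
    """Pack a list of variable-length sequences into a single LoDTensor."""
    # Offsets of each sequence in the flattened data: [0, len0, len0+len1, ...]
    lod = [0]
    for seq in data:
        lod.append(lod[-1] + len(seq))
    flattened = np.concatenate(data, axis=0).astype("int32").reshape([-1, 1])
    res = core.LoDTensor()
    res.set(flattened, place)
    res.set_lod([lod])
    return res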
@@ -7,7 +7,7 @@ distillers:
         distillation_loss_weight: 1
     l2_distiller:
         class: 'L2Distiller'
-        teacher_feature_map: 'fc_1.tmp_0'
+        teacher_feature_map: 'res_fc.tmp_0'
         student_feature_map: 'fc_0.tmp_0'
         distillation_loss_weight: 1
 strategies:
......
@@ -9,7 +9,7 @@ distillers:
         distillation_loss_weight: 1
     l2_distiller:
         class: 'L2Distiller'
-        teacher_feature_map: 'fc_1.tmp_0'
+        teacher_feature_map: 'res_fc.tmp_0'
         student_feature_map: 'fc_0.tmp_0'
         distillation_loss_weight: 1
 strategies:
......
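In both configs only teacher_feature_map changes, from fc_1.tmp_0 to res_fc.tmp_0, matching the renamed teacher fc layer (res_fc) used to avoid parameter-name conflicts with the student; see the script changes below. Conceptually, L2Distiller adds a mean squared difference between the named teacher and student feature maps, scaled by distillation_loss_weight. A NumPy sketch of that loss (illustrative only, not PaddleSlim's implementation):

import numpy as np


def l2_distillation_loss(teacher_feat, student_feat, weight=1.0):
    """Mean squared difference between two feature maps of equal shape,
    scaled by the configured distillation_loss_weight."""
    assert teacher_feat.shape == student_feat.shape
    return weight * np.mean(np.square(teacher_feat - student_feat))


# Example: a batch of 4 class-score vectors from 'res_fc.tmp_0' and 'fc_0.tmp_0'.
teacher = np.random.rand(4, 1000).astype("float32")
student = np.random.rand(4, 1000).astype("float32")
print(l2_distillation_loss(teacher, student, weight=1.0))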
@@ -34,18 +34,22 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
 # Fixing name conflicts in distillation
-mv ResNet50_pretrained/conv1_weights ResNet50_pretrained/res_conv1_weights
-mv ResNet50_pretrained/fc_0.w_0 ResNet50_pretrained/res_fc.w_0
-mv ResNet50_pretrained/fc_0.b_0 ResNet50_pretrained/res_fc.b_0
+cd ${pretrain_dir}/ResNet50_pretrained
+mv conv1_weights res_conv1_weights
+mv fc_0.w_0 res_fc.w_0
+mv fc_0.b_0 res_fc.b_0
+cd -
 python compress.py \
 --model "MobileNet" \
 --teacher_model "ResNet50" \
 --teacher_pretrained_model ./pretrain/ResNet50_pretrained \
 --compress_config ./configs/mobilenetv1_resnet50_distillation.yaml
-mv ResNet50_pretrained/res_conv1_weights ResNet50_pretrained/conv1_weights
-mv ResNet50_pretrained/res_fc.w_0 ResNet50_pretrained/fc_0.w_0
-mv ResNet50_pretrained/res_fc.b_0 ResNet50_pretrained/fc_0.b_0
+cd ${pretrain_dir}/ResNet50_pretrained
+mv res_conv1_weights conv1_weights
+mv res_fc.w_0 fc_0.w_0
+mv res_fc.b_0 fc_0.b_0
+cd -
 # for sensitivity filter pruning
 #-------------------------------
@@ -74,28 +78,32 @@ mv ResNet50_pretrained/res_fc.b_0 ResNet50_pretrained/fc_0.b_0
 # for distillation with quantization
 #-----------------------------------
-#export CUDA_VISIBLE_DEVICES=0
+#export CUDA_VISIBLE_DEVICES=4,5,6,7
 #
 ## Fixing name conflicts in distillation
-#mv ResNet50_pretrained/conv1_weights ResNet50_pretrained/res_conv1_weights
-#mv ResNet50_pretrained/fc_0.w_0 ResNet50_pretrained/res_fc.w_0
-#mv ResNet50_pretrained/fc_0.b_0 ResNet50_pretrained/res_fc.b_0
+#cd ${pretrain_dir}/ResNet50_pretrained
+#mv conv1_weights res_conv1_weights
+#mv fc_0.w_0 res_fc.w_0
+#mv fc_0.b_0 res_fc.b_0
+#cd -
 #
 #python compress.py \
 #--model "MobileNet" \
 #--teacher_model "ResNet50" \
-#--teacher_pretrained_model ./data/pretrain/ResNet50_pretrained \
+#--teacher_pretrained_model ./pretrain/ResNet50_pretrained \
 #--compress_config ./configs/quantization_dist.yaml
 #
-#mv ResNet50_pretrained/res_conv1_weights ResNet50_pretrained/conv1_weights
-#mv ResNet50_pretrained/res_fc.w_0 ResNet50_pretrained/fc_0.w_0
-#mv ResNet50_pretrained/res_fc.b_0 ResNet50_pretrained/fc_0.b_0
+#cd ${pretrain_dir}/ResNet50_pretrained
+#mv res_conv1_weights conv1_weights
+#mv res_fc.w_0 fc_0.w_0
+#mv res_fc.b_0 fc_0.b_0
+#cd -
 # for uniform filter pruning with quantization
 #---------------------------------------------
 #export CUDA_VISIBLE_DEVICES=0
 #python compress.py \
 #--model "MobileNet" \
-#--pretrained_model ./data/pretrain/MobileNetV1_pretrained \
+#--pretrained_model ./pretrain/MobileNetV1_pretrained \
 #--compress_config ./configs/quantization_pruning.yaml
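Both the distillation block and the commented-out distillation-with-quantization block in this script rename the teacher's conflicting parameter files (conv1_weights, fc_0.w_0, fc_0.b_0) before running compress.py and restore the original names afterwards, so the teacher's and the student's pretrained weights can coexist. For reference, a hedged Python equivalent of those mv pairs (rename_params and RENAMES are hypothetical helpers, not part of the repo):

import os

# Mirrors the mv commands in run.sh: teacher parameter files that clash
# with the student's parameter names get a 'res_' prefix.
RENAMES = {
    "conv1_weights": "res_conv1_weights",
    "fc_0.w_0": "res_fc.w_0",
    "fc_0.b_0": "res_fc.b_0",
}


def rename_params(pretrain_dir, reverse=False):
    """Rename the conflicting files before distillation (reverse=False)
    or restore the original names afterwards (reverse=True)."""
    for old, new in RENAMES.items():
        src, dst = (new, old) if reverse else (old, new)
        src_path = os.path.join(pretrain_dir, src)
        if os.path.exists(src_path):
            os.rename(src_path, os.path.join(pretrain_dir, dst))


# rename_params("./pretrain/ResNet50_pretrained")                 # before compress.py
# rename_params("./pretrain/ResNet50_pretrained", reverse=True)   # after compress.py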