From b67904d52897fb1d8f42567ee054e36e245674fa Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Fri, 29 Apr 2022 07:22:35 +0000 Subject: [PATCH] fix doc for kd --- .../PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml | 2 +- doc/doc_ch/knowledge_distillation.md | 18 +- doc/doc_en/knowledge_distillation_en.md | 170 ++++++++++++------ 3 files changed, 129 insertions(+), 61 deletions(-) diff --git a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml index e7cbae59..773a3649 100644 --- a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml +++ b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml @@ -129,7 +129,7 @@ Loss: key: head_out multi_head: True - DistillationSARLoss: - weight: 1.0 + weight: 0.5 model_name_list: ["Student", "Teacher"] key: head_out multi_head: True diff --git a/doc/doc_ch/knowledge_distillation.md b/doc/doc_ch/knowledge_distillation.md index cbf5a927..59341554 100644 --- a/doc/doc_ch/knowledge_distillation.md +++ b/doc/doc_ch/knowledge_distillation.md @@ -259,18 +259,18 @@ Loss: model_name_pairs: # 用于计算DML loss的子网络名称对,如果希望计算其他子网络的DML loss,可以在列表下面继续填充 - ["Student", "Teacher"] key: head_out # 取子网络输出dict中,该key对应的tensor - multi_head: True # 是否为多头结构,我们 - dis_head: ctc # 蒸馏 + multi_head: True # 是否为多头结构 + dis_head: ctc # 指定用于计算损失函数的head name: dml_ctc # 蒸馏loss的前缀名称,避免不同loss之间的命名冲突 - DistillationDMLLoss: # 蒸馏的DML损失函数,继承自标准的DMLLoss - weight: 1.0 # 权重 + weight: 0.5 # 权重 act: "softmax" # 激活函数,对输入使用激活函数处理,可以为softmax, sigmoid或者为None,默认为None use_log: true # 对输入计算log,如果函数已经 model_name_pairs: # 用于计算DML loss的子网络名称对,如果希望计算其他子网络的DML loss,可以在列表下面继续填充 - ["Student", "Teacher"] key: head_out # 取子网络输出dict中,该key对应的tensor - multi_head: True # 是否为多头结构,我们 - dis_head: sar # 蒸馏 + multi_head: True # 是否为多头结构 + dis_head: sar # 指定用于计算损失函数的head name: dml_sar # 蒸馏loss的前缀名称,避免不同loss之间的命名冲突 - DistillationDistanceLoss: # 蒸馏的距离损失函数 weight: 1.0 # 权重 @@ -286,17 +286,17 @@ Loss: weight: 1.0 # 
损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段 model_name_list: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,与gt计算CTC loss key: head_out # 取子网络输出dict中,该key对应的tensor - multi_head: True # 是否为多头结构,为true时,取出其中的 + multi_head: True # 是否为多头结构,为true时,取出其中的SAR分支计算损失函数 ``` 上述损失函数中,所有的蒸馏损失函数均继承自标准的损失函数类,主要功能为: 对蒸馏模型的输出进行解析,找到用于计算损失的中间节点(tensor),再使用标准的损失函数类去计算。 -以上述配置为例,最终蒸馏训练的损失函数包含下面3个部分。 +以上述配置为例,最终蒸馏训练的损失函数包含下面5个部分。 - `Student`和`Teacher`最终输出(`head_out`)的CTC分支与gt的CTC loss,权重为1。在这里因为2个子网络都需要更新参数,因此2者都需要计算与g的loss。 -- `Student`和`Teacher`最终输出(`head_out`)的SAR分支与gt的CTC loss,权重为1。在这里因为2个子网络都需要更新参数,因此2者都需要计算与g的loss。 +- `Student`和`Teacher`最终输出(`head_out`)的SAR分支与gt的SAR loss,权重为1.0。在这里因为2个子网络都需要更新参数,因此2者都需要计算与gt的loss。 - `Student`和`Teacher`最终输出(`head_out`)的CTC分支之间的DML loss,权重为1。 -- `Student`和`Teacher`最终输出(`head_out`)SARC分支之间的DML loss,权重为1。 +- `Student`和`Teacher`最终输出(`head_out`)的SAR分支之间的DML loss,权重为0.5。 - `Student`和`Teacher`的骨干网络输出(`backbone_out`)之间的l2 loss,权重为1。 diff --git a/doc/doc_en/knowledge_distillation_en.md b/doc/doc_en/knowledge_distillation_en.md index 1db9faef..da1e7152 100755 --- a/doc/doc_en/knowledge_distillation_en.md +++ b/doc/doc_en/knowledge_distillation_en.md @@ -74,6 +74,7 @@ The configuration file is in [ch_PP-OCRv2_rec_distillation.yml](../../configs/re #### 2.1.1 Model Structure In the knowledge distillation task, the model structure configuration is as follows. + ```yaml Architecture: model_type: &model_type "rec" # Model category, recognition, detection, etc. @@ -85,37 +86,55 @@ Architecture: freeze_params: false # Do you need fixed parameters return_all_feats: true # Do you need to return all features, if it is False, only the final output is returned model_type: *model_type # Model category - algorithm: CRNN # The algorithm name of the sub-network. The remaining parameters of the sub-network are consistent with the general model training configuration + algorithm: SVTR # The algorithm name of the sub-network. 
The remaining parameters of the sub-network are consistent with the general model training configuration Transform: Backbone: name: MobileNetV1Enhance scale: 0.5 - Neck: - name: SequenceEncoder - encoder_type: rnn - hidden_size: 64 + last_conv_stride: [1, 2] + last_pool_type: avg Head: - name: CTCHead - mid_channels: 96 - fc_decay: 0.00002 + name: MultiHead + head_list: + - CTCHead: + Neck: + name: svtr + dims: 64 + depth: 2 + hidden_dims: 120 + use_guide: True + Head: + fc_decay: 0.00001 + - SARHead: + enc_dim: 512 + max_text_length: *max_text_length Student: # Another sub-network, here is a distillation example of DML, the two sub-networks have the same structure, and both need to learn parameters pretrained: # The following parameters are the same as above freeze_params: false return_all_feats: true model_type: *model_type - algorithm: CRNN + algorithm: SVTR Transform: Backbone: name: MobileNetV1Enhance scale: 0.5 - Neck: - name: SequenceEncoder - encoder_type: rnn - hidden_size: 64 + last_conv_stride: [1, 2] + last_pool_type: avg Head: - name: CTCHead - mid_channels: 96 - fc_decay: 0.00002 + name: MultiHead + head_list: + - CTCHead: + Neck: + name: svtr + dims: 64 + depth: 2 + hidden_dims: 120 + use_guide: True + Head: + fc_decay: 0.00001 + - SARHead: + enc_dim: 512 + max_text_length: *max_text_length ``` If you want to add more sub-networks for training, you can also add the corresponding fields in the configuration file according to the way of adding `Student` and `Teacher`. 
@@ -132,55 +151,83 @@ Architecture: freeze_params: false return_all_feats: true model_type: *model_type - algorithm: CRNN + algorithm: SVTR Transform: Backbone: name: MobileNetV1Enhance scale: 0.5 - Neck: - name: SequenceEncoder - encoder_type: rnn - hidden_size: 64 + last_conv_stride: [1, 2] + last_pool_type: avg Head: - name: CTCHead - mid_channels: 96 - fc_decay: 0.00002 + name: MultiHead + head_list: + - CTCHead: + Neck: + name: svtr + dims: 64 + depth: 2 + hidden_dims: 120 + use_guide: True + Head: + fc_decay: 0.00001 + - SARHead: + enc_dim: 512 + max_text_length: *max_text_length Student: pretrained: freeze_params: false return_all_feats: true model_type: *model_type - algorithm: CRNN + algorithm: SVTR Transform: Backbone: name: MobileNetV1Enhance scale: 0.5 - Neck: - name: SequenceEncoder - encoder_type: rnn - hidden_size: 64 + last_conv_stride: [1, 2] + last_pool_type: avg Head: - name: CTCHead - mid_channels: 96 - fc_decay: 0.00002 - Student2: # The new sub-network introduced in the knowledge distillation task, the configuration is the same as above + name: MultiHead + head_list: + - CTCHead: + Neck: + name: svtr + dims: 64 + depth: 2 + hidden_dims: 120 + use_guide: True + Head: + fc_decay: 0.00001 + - SARHead: + enc_dim: 512 + max_text_length: *max_text_length + Student2: pretrained: freeze_params: false return_all_feats: true model_type: *model_type - algorithm: CRNN + algorithm: SVTR Transform: Backbone: name: MobileNetV1Enhance scale: 0.5 - Neck: - name: SequenceEncoder - encoder_type: rnn - hidden_size: 64 + last_conv_stride: [1, 2] + last_pool_type: avg Head: - name: CTCHead - mid_channels: 96 - fc_decay: 0.00002 + name: MultiHead + head_list: + - CTCHead: + Neck: + name: svtr + dims: 64 + depth: 2 + hidden_dims: 120 + use_guide: True + Head: + fc_decay: 0.00001 + - SARHead: + enc_dim: 512 + max_text_length: *max_text_length +``` ``` When the model is finally trained, it contains 3 sub-networks: `Teacher`, `Student`, `Student2`. 
@@ -224,23 +271,42 @@ Loss: act: "softmax" # Activation function, use it to process the input, can be softmax, sigmoid or None, the default is None model_name_pairs: # The subnet name pair used to calculate DML loss. If you want to calculate the DML loss of other subnets, you can continue to add it below the list - ["Student", "Teacher"] - key: head_out + key: head_out + multi_head: True # whether to use multi_head + dis_head: ctc # assign the head name to calculate loss + name: dml_ctc # prefix name of the loss + - DistillationDMLLoss: # DML loss function, inherited from the standard DMLLoss + weight: 0.5 + act: "softmax" # Activation function, use it to process the input, can be softmax, sigmoid or None, the default is None + model_name_pairs: # The subnet name pair used to calculate DML loss. If you want to calculate the DML loss of other subnets, you can continue to add it below the list + - ["Student", "Teacher"] + key: head_out + multi_head: True # whether to use multi_head + dis_head: sar # assign the head name to calculate loss + name: dml_sar # prefix name of the loss - DistillationDistanceLoss: # Distilled distance loss function weight: 1.0 mode: "l2" # Support l1, l2 or smooth_l1 model_name_pairs: # Calculate the distance loss of the subnet name pair - ["Student", "Teacher"] key: backbone_out + - DistillationSARLoss: # SAR loss function based on distillation, inherited from standard SAR loss + weight: 1.0 # The weight of the loss function. 
In loss_config_list, each loss function must include this field + model_name_list: ["Student", "Teacher"] # For the prediction results of the distillation model, extract the output of these two sub-networks and calculate the SAR loss with gt + key: head_out # In the sub-network output dict, take the corresponding tensor + multi_head: True # whether it is multi-head or not, if true, SAR branch is used to calculate the loss ``` Among the above loss functions, all distillation loss functions are inherited from the standard loss function class. The main functions are: Analyze the output of the distillation model, find the intermediate node (tensor) used to calculate the loss, and then use the standard loss function class to calculate. -Taking the above configuration as an example, the final distillation training loss function contains the following three parts. +Taking the above configuration as an example, the final distillation training loss function contains the following five parts. -- The final output `head_out` of `Student` and `Teacher` calculates the CTC loss with gt (loss weight equals 1.0). Here, because both sub-networks need to update the parameters, both of them need to calculate the loss with gt. -- DML loss between `Student` and `Teacher`'s final output `head_out` (loss weight equals 1.0). +- CTC branch of the final output `head_out` for `Student` and `Teacher` calculates the CTC loss with gt (loss weight equals 1.0). Here, because both sub-networks need to update the parameters, both of them need to calculate the loss with gt. +- SAR branch of the final output `head_out` for `Student` and `Teacher` calculates the SAR loss with gt (loss weight equals 1.0). Here, because both sub-networks need to update the parameters, both of them need to calculate the loss with gt. +- DML loss between CTC branch of `Student` and `Teacher`'s final output `head_out` (loss weight equals 1.0). 
+- DML loss between SAR branch of `Student` and `Teacher`'s final output `head_out` (loss weight equals 0.5). - L2 loss between `Student` and `Teacher`'s backbone network output `backbone_out` (loss weight equals 1.0). For more specific implementation of `CombinedLoss`, please refer to: [combined_loss.py](../../ppocr/losses/combined_loss.py#L23). @@ -257,6 +323,7 @@ PostProcess: name: DistillationCTCLabelDecode # CTC decoding post-processing of distillation tasks, inherited from the standard CTCLabelDecode class model_name: ["Student", "Teacher"] # For the prediction results of the distillation model, extract the outputs of these two sub-networks and decode them key: head_out # Take the corresponding tensor in the subnet output dict + multi_head: True # whether it is multi-head or not, if true, CTC branch is used for decoding ``` Taking the above configuration as an example, the CTC decoding output of the two sub-networks `Student` and `Teahcer` will be calculated at the same time. @@ -276,6 +343,7 @@ Metric: base_metric_name: RecMetric # The base class of indicator calculation. For the output of the model, the indicator will be calculated based on this class main_indicator: acc # The name of the indicator key: "Student" # Select the main_indicator of this subnet as the criterion for saving the best model + ignore_space: False # whether to ignore spaces during evaluation ``` Taking the above configuration as an example, the accuracy metric of the `Student` subnet will be used as the judgment metric for saving the best model. @@ -289,13 +357,13 @@ For more specific implementation of `DistillationMetric`, please refer to: [dist There are two ways to fine-tune the recognition distillation task. -1. Fine-tuning based on knowledge distillation: this situation is relatively simple, download the pre-trained model. 
Then configure the pre-training model path and your own data path in [ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml) to perform fine-tuning training of the model. +1. Fine-tuning based on knowledge distillation: this situation is relatively simple, download the pre-trained model. Then configure the pre-training model path and your own data path in [ch_PP-OCRv3_rec_distillation.yml](../../configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml) to perform fine-tuning training of the model. 2. Do not use knowledge distillation in fine-tuning: In this case, you need to first extract the student model parameters from the pre-training model. The specific steps are as follows. - First download the pre-trained model and unzip it. ```shell -wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar -tar -xf ch_PP-OCRv2_rec_train.tar +wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar +tar -xf ch_PP-OCRv3_rec_train.tar ``` - Then use python to extract the student model parameters @@ -303,7 +371,7 @@ tar -xf ch_PP-OCRv2_rec_train.tar ```python import paddle # Load the pre-trained model -all_params = paddle.load("ch_PP-OCRv2_rec_train/best_accuracy.pdparams") +all_params = paddle.load("ch_PP-OCRv3_rec_train/best_accuracy.pdparams") # View the keys of the weight parameter print(all_params.keys()) # Weight extraction of student model @@ -311,10 +379,10 @@ s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Stu # View the keys of the weight parameters of the student model print(s_params.keys()) # Save weight parameters -paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams") +paddle.save(s_params, "ch_PP-OCRv3_rec_train/student.pdparams") ``` -After the extraction is complete, use [ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml) to modify the path of the pre-trained model (the path of the exported `student.pdparams` model) 
and your own data path to fine-tune the model. +After the extraction is complete, use [ch_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml) to modify the path of the pre-trained model (the path of the exported `student.pdparams` model) and your own data path to fine-tune the model. ### 2.2 Detection Model Configuration File Analysis -- GitLab