From 9f8f1a4e4f9df5d473f5da49ef8440b9d6d0602b Mon Sep 17 00:00:00 2001
From: ceci3
Date: Mon, 27 Jun 2022 20:03:12 +0800
Subject: [PATCH] Fix transformer prune reorder (#1194)

---
 paddleslim/auto_compression/auto_strategy.py    |  5 ++--
 .../auto_compression/transformer_pruner.py      | 26 ++++++++++++++++---
 2 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/paddleslim/auto_compression/auto_strategy.py b/paddleslim/auto_compression/auto_strategy.py
index a6311a3d..9bfc850f 100644
--- a/paddleslim/auto_compression/auto_strategy.py
+++ b/paddleslim/auto_compression/auto_strategy.py
@@ -76,8 +76,9 @@ EXPERIENCE_STRATEGY_WITHOUT_LOSS = [
 ]
 MAGIC_SPARSE_RATIO = 0.75
 ### TODO: 0.02 threshold maybe not suitable, need to check
-MAGIC_MAX_EMD_DISTANCE = 0.02
-MAGIC_MIN_EMD_DISTANCE = 0.01
+### NOTE: reduce magic data to choose quantization aware training.
+MAGIC_MAX_EMD_DISTANCE = 0.0002  #0.02
+MAGIC_MIN_EMD_DISTANCE = 0.0001  #0.01
 
 DEFAULT_TRANSFORMER_STRATEGY = 'prune_0.25_int8'
 DEFAULT_STRATEGY = 'origin_int8'
diff --git a/paddleslim/auto_compression/transformer_pruner.py b/paddleslim/auto_compression/transformer_pruner.py
index 0cb3011e..e21178d4 100644
--- a/paddleslim/auto_compression/transformer_pruner.py
+++ b/paddleslim/auto_compression/transformer_pruner.py
@@ -204,6 +204,23 @@ def softmax_with_cross_entropy_op(block, logits, labels):
     return loss, softmax
 
 
+def kl_div_op(block, logits, labels):
+    """ Insert kl_div op to program"""
+    global global_idx
+    loss = block.create_var(
+        name='{}.kl_div_tmp_{}'.format(logits.name, global_idx),
+        shape=logits.shape,
+        dtype=logits.dtype)
+    global_idx += 1
+
+    attrs = {'reduction': "mean"}  ### maybe take a long time use this attrs
+    inputs = {'X': logits, 'Target': labels}
+    outputs = {'Loss': loss}
+    block.append_op(
+        type='kldiv_loss', inputs=inputs, outputs=outputs, attrs=attrs)
+    return loss
+
+
 def mean_op(block, inputs, axis=None, keepdim=False):
     """ Insert mean op to program"""
     global global_idx
@@ -331,9 +348,12 @@ class TransformerPruner:
             dtype=label_info['dtype'],
             persistable=False)
         labels = feed_op(block, feed_num, labels)
-        ce_loss, probs = softmax_with_cross_entropy_op(
-            block, logits=logits, labels=labels)
-        loss = mean_op(block, ce_loss)
+        if label_info['dtype'] == np.float32:
+            loss = kl_div_op(block, logits=logits, labels=labels)
+        else:
+            loss, probs = softmax_with_cross_entropy_op(
+                block, logits=logits, labels=labels)
+        loss = mean_op(block, loss)
 
         program._sync_with_cpp()
         paddle.static.append_backward(loss)
--
GitLab
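
Note (not part of the patch): the second transformer_pruner.py hunk makes TransformerPruner pick its reorder loss from the label dtype: float32 labels are treated as soft targets and scored with a kldiv_loss op, while integer labels keep the original softmax-with-cross-entropy path. Below is a minimal dynamic-graph sketch of that selection logic; the helper name reorder_loss is hypothetical, and it feeds log_softmax outputs to paddle.nn.functional.kl_div (the log-probability form that API expects), whereas the patch appends the equivalent static-graph ops directly to the program.

# Illustrative sketch only; mirrors the dtype-based loss selection added by
# this patch, expressed with Paddle's dynamic-graph APIs.
import paddle
import paddle.nn.functional as F

def reorder_loss(logits, labels):  # hypothetical helper, not in the patch
    if labels.dtype == paddle.float32:
        # Soft (probability) labels: KL divergence, analogous to the patch's
        # kl_div_op, which appends a kldiv_loss op with reduction="mean".
        return F.kl_div(F.log_softmax(logits, axis=-1), labels, reduction='mean')
    # Hard integer labels: the pre-existing softmax-with-cross-entropy path.
    ce = F.softmax_with_cross_entropy(logits, labels)
    return paddle.mean(ce)

logits = paddle.randn([4, 10])
soft_labels = F.softmax(paddle.randn([4, 10]), axis=-1)           # float32 targets
hard_labels = paddle.randint(0, 10, shape=[4, 1], dtype='int64')  # class indices
print(reorder_loss(logits, soft_labels))
print(reorder_loss(logits, hard_labels))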