未验证 提交 9f8f1a4e 编写于 作者: C ceci3 提交者: GitHub

Fix transformer prune reorder (#1194)

上级 b3499dc9
......@@ -76,8 +76,9 @@ EXPERIENCE_STRATEGY_WITHOUT_LOSS = [
]
MAGIC_SPARSE_RATIO = 0.75
### TODO: 0.02 threshold maybe not suitable, need to check
MAGIC_MAX_EMD_DISTANCE = 0.02
MAGIC_MIN_EMD_DISTANCE = 0.01
### NOTE: reduce magic data to choose quantization aware training.
MAGIC_MAX_EMD_DISTANCE = 0.0002 #0.02
MAGIC_MIN_EMD_DISTANCE = 0.0001 #0.01
DEFAULT_TRANSFORMER_STRATEGY = 'prune_0.25_int8'
DEFAULT_STRATEGY = 'origin_int8'
......
......@@ -204,6 +204,23 @@ def softmax_with_cross_entropy_op(block, logits, labels):
return loss, softmax
def kl_div_op(block, logits, labels):
""" Insert kl_div op to program"""
global global_idx
loss = block.create_var(
name='{}.kl_div_tmp_{}'.format(logits.name, global_idx),
shape=logits.shape,
dtype=logits.dtype)
global_idx += 1
attrs = {'reduction': "mean"} ### maybe take a long time use this attrs
inputs = {'X': logits, 'Target': labels}
outputs = {'Loss': loss}
block.append_op(
type='kldiv_loss', inputs=inputs, outputs=outputs, attrs=attrs)
return loss
def mean_op(block, inputs, axis=None, keepdim=False):
""" Insert mean op to program"""
global global_idx
......@@ -331,9 +348,12 @@ class TransformerPruner:
dtype=label_info['dtype'],
persistable=False)
labels = feed_op(block, feed_num, labels)
ce_loss, probs = softmax_with_cross_entropy_op(
if label_info['dtype'] == np.float32:
loss = kl_div_op(block, logits=logits, labels=labels)
else:
loss, probs = softmax_with_cross_entropy_op(
block, logits=logits, labels=labels)
loss = mean_op(block, ce_loss)
loss = mean_op(block, loss)
program._sync_with_cpp()
paddle.static.append_backward(loss)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册