Unverified commit 86f6115c authored by Z zhouzj, committed by GitHub

[ACT] add loss info (#1597)

* Add loss info to ACT training.

* Add flops info.
Parent dab888af
@@ -28,7 +28,19 @@ def analysis_prune(eval_function,
params_filename,
analysis_file,
pruned_ratios,
target_loss=None):
target_loss=None,
criterion='l1_norm'):
'''
Args:
eval_function(function): The callback function used to evaluate the model. It should accept an instance of `paddle.static.Program` as its argument and return a score on the test dataset.
model_dir(str): Directory path to load the model. To load an ONNX model, set ``model_dir=model.onnx``.
model_filename(str): Specify model_filename. To load an ONNX model, the model filename should be None.
params_filename(str): Specify params_filename. To load an ONNX model, the params filename should be None.
analysis_file(str): The file used to save the sensitivities. Newly computed sensitivities are appended to it, and sensitivities already stored in the file are not computed again. The file can be loaded with the `pickle` library.
pruned_ratios(list): The ratios to be pruned.
criterion(str|function): The criterion used to sort channels for pruning. Currently supports l1_norm, bn_scale, geometry_median. Default: l1_norm.
'''
devices = paddle.device.get_device().split(':')[0]
places = paddle.device._convert_to_place(devices)
exe = paddle.static.Executor(places)
@@ -47,7 +59,8 @@ def analysis_prune(eval_function,
eval_function,
sensitivities_file=analysis_file,
eval_args=[exe, feed_target_names, fetch_targets],
pruned_ratios=pruned_ratios)
pruned_ratios=pruned_ratios,
criterion=criterion)
with open(analysis_file, 'rb') as f:
if sys.version_info < (3, 0):
......
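For context, here is a hedged usage sketch of the extended `analysis_prune` signature. The import path, model file names, ratios, and the callback signature (program plus the `eval_args` passed above) are illustrative assumptions, not part of this commit.

```python
# Hypothetical call to analysis_prune with the new `criterion` argument.
from paddleslim.auto_compression.analysis import analysis_prune  # assumed import path

def eval_function(program, exe, feed_target_names, fetch_targets):
    # Placeholder: run the candidate pruned program on a validation set
    # and return a scalar score such as top-1 accuracy.
    return 0.0

analysis_prune(
    eval_function,
    model_dir='./inference_model',
    model_filename='model.pdmodel',
    params_filename='model.pdiparams',
    analysis_file='sensitivities.pkl',
    pruned_ratios=[0.1, 0.2, 0.3],
    criterion='l1_norm')  # also accepts 'bn_scale' or 'geometry_median'
```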
@@ -783,13 +783,17 @@ class AutoCompression:
total_epochs = train_config.epochs if train_config.epochs else 100
total_train_iter = 0
stop_training = False
loss_vars = [var for var in train_program_info.loss_dict.values()]
loss_names = [name for name in train_program_info.loss_dict.keys()]
for epoch_id in range(total_epochs):
if stop_training:
break
for batch_id, data in enumerate(self.train_dataloader()):
np_probs_float, = self._exe.run(train_program_info.program, \
loss = self._exe.run(train_program_info.program, \
feed=data, \
fetch_list=train_program_info.fetch_targets)
fetch_list=train_program_info.fetch_targets+loss_vars)
if not isinstance(train_program_info.learning_rate, float):
train_program_info.learning_rate.step()
if 'unstructure' in strategy:
@@ -800,10 +804,12 @@
else:
logging_iter = train_config.logging_iter
if batch_id % int(logging_iter) == 0:
_logger.info(
"Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
total_train_iter, epoch_id, batch_id,
np_probs_float))
print_info = "Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
total_train_iter, epoch_id, batch_id, loss[0])
for idx, loss_value in enumerate(loss[1:]):
print_info += '{}: {} '.format(loss_names[idx],
loss_value)
_logger.info(print_info)
total_train_iter += 1
if total_train_iter % int(
train_config.eval_iter) == 0 and total_train_iter != 0:
......
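The new logging path fetches the individual distillation losses (the values of `loss_dict`) alongside the original fetch targets and prints each one by name. A minimal, self-contained sketch of that string assembly follows; the loss names and values are made up, and `loss` stands for the list returned by `exe.run`.

```python
# Sketch of the per-loss logging pattern added above (illustrative values).
loss_names = ['l2', 'soft_label']      # keys of train_program_info.loss_dict
loss = [0.93, 0.61, 0.32]              # [fetch_target_0, l2, soft_label]

total_train_iter, epoch_id, batch_id = 100, 1, 20
print_info = "Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
    total_train_iter, epoch_id, batch_id, loss[0])
for idx, loss_value in enumerate(loss[1:]):
    print_info += ' {}: {}'.format(loss_names[idx], loss_value)
print(print_info)
# Total iter: 100, epoch: 1, batch: 20, loss: 0.93 l2: 0.61 soft_label: 0.32
```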
@@ -24,6 +24,7 @@ from ..common.recover_program import recover_inference_program, _remove_fetch_no
from ..common import get_logger
from .strategy_config import ProgramInfo
from ..common.load_model import load_inference_model
from ..analysis import flops
_logger = get_logger(__name__, level=logging.INFO)
__all__ = [
@@ -133,7 +134,7 @@ def _parse_distill_loss(distill_node_pair,
distill_lambda=1.0):
"""parse distill loss config"""
loss_dist = 0.0
losses = []
losses = {}
if isinstance(distill_node_pair[0], str):
assert isinstance(distill_loss, str)
assert isinstance(distill_lambda, float)
@@ -143,16 +144,17 @@
assert len(distill_node_pair) == len(distill_loss)
assert len(distill_node_pair) == len(distill_lambda)
for node, loss, lam in zip(distill_node_pair, distill_loss, distill_lambda):
tmp_loss = 0.0
_logger.info("train config.distill_node_pair: {}".format(node, loss,
lam))
for node, loss_clas, lam in zip(distill_node_pair, distill_loss,
distill_lambda):
tmp_loss = losses.get(loss_clas, 0.0)
_logger.info("train config.distill_node_pair: {}".format(
node, loss_clas, lam))
assert len(node) % 2 == 0, \
"distill_node_pair config wrong, the length needs to be an even number"
for i in range(len(node) // 2):
tmp_loss += eval(loss)(node[i * 2], node[i * 2 + 1])
loss_dist += lam * tmp_loss
losses.append(tmp_loss)
tmp_loss += eval(loss_clas)(node[i * 2], node[i * 2 + 1]) * lam
loss_dist += tmp_loss
losses[loss_clas] = tmp_loss
return loss_dist, losses
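To illustrate the switch from a flat list to a dict of losses, here is a small stand-alone sketch of the grouping logic. The node pairs, loss names, lambdas, and the stand-in loss value are invented; the real code builds Paddle loss ops via `eval(loss_clas)(teacher_node, student_node)` rather than adding constants.

```python
# Stand-alone sketch: losses are now accumulated per loss type in a dict.
distill_node_pair = [['teacher_fc_0', 'fc_0'], ['teacher_softmax', 'softmax']]
distill_loss = ['l2', 'soft_label']
distill_lambda = [1.0, 0.5]

loss_dist, losses = 0.0, {}
for node, loss_clas, lam in zip(distill_node_pair, distill_loss, distill_lambda):
    tmp_loss = losses.get(loss_clas, 0.0)
    for i in range(len(node) // 2):
        tmp_loss += 1.0 * lam    # stand-in for eval(loss_clas)(node[2*i], node[2*i+1]) * lam
    loss_dist += tmp_loss
    losses[loss_clas] = tmp_loss

print(loss_dist)   # 1.5
print(losses)      # {'l2': 1.0, 'soft_label': 0.5}
```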
@@ -364,7 +366,7 @@ def build_distill_program(executor,
use_dynamic_loss_scaling=True,
**train_config['amp_config'])
distill_loss, losses = _parse_distill_loss(
distill_loss, loss_dict = _parse_distill_loss(
distill_node_pair,
config.get('loss') or 'l2', ### default loss is l2
config.get('alpha') or 1.0) ### default alpha is 1.0
@@ -385,7 +387,7 @@
train_program_info = ProgramInfo(startup_program, train_program,
feed_target_names, train_fetch_list,
optimizer, learning_rate)
optimizer, learning_rate, loss_dict)
test_program_info = ProgramInfo(startup_program, test_program,
feed_target_names, fetch_targets)
return train_program_info, test_program_info
@@ -520,6 +522,8 @@ def build_prune_program(executor,
params.append(param.name)
original_shapes[param.name] = param.shape
origin_flops = flops(train_program_info.program)
pruned_program, _, _ = pruner.prune(
train_program_info.program,
paddle.static.global_scope(),
@@ -530,12 +534,18 @@
place=place)
_logger.info(
"####################channel pruning##########################")
for param in pruned_program.all_parameters():
for param in pruned_program.global_block().all_parameters():
if param.name in original_shapes:
_logger.info("{}, from {} to {}".format(
param.name, original_shapes[param.name], param.shape))
_logger.info(
"####################channel pruning end##########################")
final_flops = flops(pruned_program)
pruned_flops = abs(origin_flops - final_flops) / origin_flops
_logger.info("FLOPs before pruning: {}".format(origin_flops))
_logger.info("FLOPs after pruning: {}. Pruned FLOPs: {}%.".format(
final_flops, round(pruned_flops * 100, 2)))
train_program_info.program = pruned_program
elif strategy.startswith('asp'):
......
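The pruning branch now reports FLOPs before and after pruning via `paddleslim.analysis.flops`. Below is a small, hedged sketch that computes FLOPs for a toy static-graph program and reproduces the percentage formula above; the conv program and the assumed `final_flops` value are illustrative only.

```python
# Toy sketch of the FLOPs bookkeeping added in this commit.
import paddle
from paddleslim.analysis import flops  # same helper the diff imports

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[1, 3, 32, 32], dtype='float32')
    y = paddle.static.nn.conv2d(x, num_filters=8, filter_size=3)

origin_flops = flops(main_prog)            # FLOPs before pruning
final_flops = int(origin_flops * 0.75)     # stand-in for flops(pruned_program)
pruned_flops = abs(origin_flops - final_flops) / origin_flops
print("FLOPs before pruning: {}".format(origin_flops))
print("FLOPs after pruning: {}. Pruned FLOPs: {}%.".format(
    final_flops, round(pruned_flops * 100, 2)))
```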
@@ -431,7 +431,8 @@ class ProgramInfo:
feed_target_names,
fetch_targets,
optimizer=None,
learning_rate=None):
learning_rate=None,
loss_dict=None):
"""
ProgramInfo Config.
Args:
@@ -441,6 +442,7 @@
fetch_targets(list(Variable)): The fetch variable in the program.
optimizer(Optimizer, optional): Optimizer in training. Default: None.
learning_rate(float|paddle.optimizer.lr, optional): learning_rate in training. Default: None.
loss_dict(dict, optional): The component losses of the training program, keyed by loss name. Default: None.
"""
self.startup_program = startup_program
self.program = program
@@ -448,3 +450,4 @@
self.fetch_targets = fetch_targets
self.optimizer = optimizer
self.learning_rate = learning_rate
self.loss_dict = loss_dict
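Finally, a hedged sketch of constructing `ProgramInfo` with the new field. The import path and the None placeholders are assumptions; in the real flow, `build_distill_program` fills these slots with actual programs, an optimizer, and the loss variables from `_parse_distill_loss`.

```python
# Hypothetical construction of ProgramInfo carrying the new loss_dict;
# None placeholders stand in for real programs, optimizer, and loss vars.
from paddleslim.auto_compression.strategy_config import ProgramInfo  # assumed import path

info = ProgramInfo(
    startup_program=None,
    program=None,
    feed_target_names=['image'],
    fetch_targets=[],
    optimizer=None,
    learning_rate=None,
    loss_dict={'l2': None, 'soft_label': None})
print(info.loss_dict)   # {'l2': None, 'soft_label': None}
```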