Unverified commit 86f6115c, authored by zhouzj, committed by GitHub

[ACT] add loss info (#1597)

* Add loss info to ACT training.

* Add FLOPs info.
Parent dab888af
@@ -28,7 +28,19 @@ def analysis_prune(eval_function,
                    params_filename,
                    analysis_file,
                    pruned_ratios,
-                   target_loss=None):
+                   target_loss=None,
+                   criterion='l1_norm'):
+    '''
+    Args:
+        eval_function(function): The callback used to evaluate the model. It should accept an instance of `paddle.static.Program` as argument and return a score on the test dataset.
+        model_dir(str): Directory path to load the model from. To load an ONNX model, set ``model_dir=model.onnx``.
+        model_filename(str): Name of the model file. Should be None when loading an ONNX model.
+        params_filename(str): Name of the params file. Should be None when loading an ONNX model.
+        analysis_file(str): The file to save the sensitivities. The latest computed sensitivities are appended to it, and sensitivities already in the file are not computed again. The file can be loaded with the `pickle` library.
+        pruned_ratios(list): The ratios to be pruned.
+        criterion(str|function): The criterion used to sort channels for pruning. Currently supports l1_norm, bn_scale and geometry_median. Default: l1_norm.
+    '''
     devices = paddle.device.get_device().split(':')[0]
     places = paddle.device._convert_to_place(devices)
     exe = paddle.static.Executor(places)
@@ -47,7 +59,8 @@ def analysis_prune(eval_function,
         eval_function,
         sensitivities_file=analysis_file,
         eval_args=[exe, feed_target_names, fetch_targets],
-        pruned_ratios=pruned_ratios)
+        pruned_ratios=pruned_ratios,
+        criterion=criterion)
     with open(analysis_file, 'rb') as f:
         if sys.version_info < (3, 0):
...
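For context, here is a minimal usage sketch of `analysis_prune` with the new `criterion` argument. The import path, model file names, target loss value, and the evaluation callback body are assumptions for illustration only; they are not part of this commit, and the exact callback signature depends on the sensitivity API.

```python
# Hypothetical usage of analysis_prune with the new `criterion` argument.
# Import path and file names are assumed, not taken from this commit.
from paddleslim.auto_compression.analysis import analysis_prune

def eval_function(program, exe, feed_target_names, fetch_targets):
    # Placeholder: run `program` on a held-out set and return a scalar score.
    return 0.0

analysis_prune(
    eval_function,
    model_dir='./inference_model',
    model_filename='model.pdmodel',
    params_filename='model.pdiparams',
    analysis_file='sensitivity.pkl',
    pruned_ratios=[0.1, 0.2, 0.3],
    target_loss=0.05,
    criterion='l1_norm')  # alternatives: 'bn_scale', 'geometry_median'
```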
@@ -783,13 +783,17 @@ class AutoCompression:
         total_epochs = train_config.epochs if train_config.epochs else 100
         total_train_iter = 0
         stop_training = False
+        loss_vars = [var for var in train_program_info.loss_dict.values()]
+        loss_names = [name for name in train_program_info.loss_dict.keys()]
         for epoch_id in range(total_epochs):
             if stop_training:
                 break
             for batch_id, data in enumerate(self.train_dataloader()):
-                np_probs_float, = self._exe.run(train_program_info.program, \
+                loss = self._exe.run(train_program_info.program, \
                     feed=data, \
-                    fetch_list=train_program_info.fetch_targets)
+                    fetch_list=train_program_info.fetch_targets+loss_vars)
                 if not isinstance(train_program_info.learning_rate, float):
                     train_program_info.learning_rate.step()
                 if 'unstructure' in strategy:
@@ -800,10 +804,12 @@ class AutoCompression:
                 else:
                     logging_iter = train_config.logging_iter
                 if batch_id % int(logging_iter) == 0:
-                    _logger.info(
-                        "Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
-                            total_train_iter, epoch_id, batch_id,
-                            np_probs_float))
+                    print_info = "Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
+                        total_train_iter, epoch_id, batch_id, loss[0])
+                    for idx, loss_value in enumerate(loss[1:]):
+                        print_info += '{}: {} '.format(loss_names[idx],
+                                                       loss_value)
+                    _logger.info(print_info)
                 total_train_iter += 1
                 if total_train_iter % int(
                         train_config.eval_iter) == 0 and total_train_iter != 0:
...
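To make the new logging concrete, here is a small, framework-free sketch of the pattern introduced above: component losses are fetched after the original fetch targets, and each one is appended to the log line by name. The dictionary contents and numbers are made up for illustration.

```python
# Standalone sketch of the per-loss logging added above; values are dummies.
loss_dict = {'soft_label': 0.12, 'l2': 0.03}   # stands in for train_program_info.loss_dict
loss_names = list(loss_dict.keys())

# In the real loop, exe.run returns values for fetch_targets + loss_vars;
# here the first entry plays the role of the original fetch target.
loss = [0.15] + list(loss_dict.values())

total_train_iter, epoch_id, batch_id = 100, 1, 20
print_info = "Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
    total_train_iter, epoch_id, batch_id, loss[0])
for idx, loss_value in enumerate(loss[1:]):
    print_info += ' {}: {}'.format(loss_names[idx], loss_value)
print(print_info)
# Total iter: 100, epoch: 1, batch: 20, loss: 0.15 soft_label: 0.12 l2: 0.03
```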
@@ -24,6 +24,7 @@ from ..common.recover_program import recover_inference_program, _remove_fetch_no
 from ..common import get_logger
 from .strategy_config import ProgramInfo
 from ..common.load_model import load_inference_model
+from ..analysis import flops

 _logger = get_logger(__name__, level=logging.INFO)
 __all__ = [
@@ -133,7 +134,7 @@ def _parse_distill_loss(distill_node_pair,
                         distill_lambda=1.0):
     """parse distill loss config"""
     loss_dist = 0.0
-    losses = []
+    losses = {}
     if isinstance(distill_node_pair[0], str):
         assert isinstance(distill_loss, str)
         assert isinstance(distill_lambda, float)
@@ -143,16 +144,17 @@ def _parse_distill_loss(distill_node_pair,
     assert len(distill_node_pair) == len(distill_loss)
     assert len(distill_node_pair) == len(distill_lambda)
-    for node, loss, lam in zip(distill_node_pair, distill_loss, distill_lambda):
-        tmp_loss = 0.0
-        _logger.info("train config.distill_node_pair: {}".format(node, loss,
-                                                                 lam))
+    for node, loss_clas, lam in zip(distill_node_pair, distill_loss,
+                                    distill_lambda):
+        tmp_loss = losses.get(loss_clas, 0.0)
+        _logger.info("train config.distill_node_pair: {}".format(
+            node, loss_clas, lam))
         assert len(node) % 2 == 0, \
             "distill_node_pair config wrong, the length needs to be an even number"
         for i in range(len(node) // 2):
-            tmp_loss += eval(loss)(node[i * 2], node[i * 2 + 1])
-        loss_dist += lam * tmp_loss
-        losses.append(tmp_loss)
+            tmp_loss += eval(loss_clas)(node[i * 2], node[i * 2 + 1]) * lam
+        loss_dist += tmp_loss
+        losses[loss_clas] = tmp_loss
     return loss_dist, losses
@@ -364,7 +366,7 @@ def build_distill_program(executor,
             use_dynamic_loss_scaling=True,
             **train_config['amp_config'])
-        distill_loss, losses = _parse_distill_loss(
+        distill_loss, loss_dict = _parse_distill_loss(
             distill_node_pair,
             config.get('loss') or 'l2',  ### default loss is l2
             config.get('alpha') or 1.0)  ### default alpha is 1.0
@@ -385,7 +387,7 @@ def build_distill_program(executor,
     train_program_info = ProgramInfo(startup_program, train_program,
                                      feed_target_names, train_fetch_list,
-                                     optimizer, learning_rate)
+                                     optimizer, learning_rate, loss_dict)
     test_program_info = ProgramInfo(startup_program, test_program,
                                     feed_target_names, fetch_targets)
     return train_program_info, test_program_info
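The change from a list to a dict in `_parse_distill_loss` lets later code look up each component loss by its loss name, and node pairs that share a loss type are accumulated under one key. Below is a simplified, self-contained sketch of that accumulation; the loss functions and node values are placeholders, since the real code operates on program tensors.

```python
# Simplified model of the dict-based accumulation in _parse_distill_loss.
# Placeholder losses on plain floats; the real losses act on graph tensors.
def l2(a, b):
    return (a - b) ** 2

def soft_label(a, b):
    return abs(a - b)

distill_node_pair = [[1.0, 1.2], [0.5, 0.9]]   # (student, teacher) value pairs
distill_loss = ['l2', 'soft_label']
distill_lambda = [1.0, 0.5]

loss_dist, losses = 0.0, {}
for node, loss_clas, lam in zip(distill_node_pair, distill_loss, distill_lambda):
    tmp_loss = losses.get(loss_clas, 0.0)       # reuse any earlier value for this loss type
    for i in range(len(node) // 2):
        tmp_loss += eval(loss_clas)(node[i * 2], node[i * 2 + 1]) * lam
    loss_dist += tmp_loss
    losses[loss_clas] = tmp_loss

print(loss_dist)   # total distillation loss
print(losses)      # {'l2': ..., 'soft_label': ...} -> consumed as loss_dict downstream
```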
@@ -520,6 +522,8 @@ def build_prune_program(executor,
                 params.append(param.name)
                 original_shapes[param.name] = param.shape
+        origin_flops = flops(train_program_info.program)
         pruned_program, _, _ = pruner.prune(
             train_program_info.program,
             paddle.static.global_scope(),
@@ -530,12 +534,18 @@ def build_prune_program(executor,
             place=place)
         _logger.info(
             "####################channel pruning##########################")
-        for param in pruned_program.all_parameters():
+        for param in pruned_program.global_block().all_parameters():
             if param.name in original_shapes:
                 _logger.info("{}, from {} to {}".format(
                     param.name, original_shapes[param.name], param.shape))
         _logger.info(
             "####################channel pruning end##########################")
+        final_flops = flops(pruned_program)
+        pruned_flops = abs(origin_flops - final_flops) / origin_flops
+        _logger.info("FLOPs before pruning: {}".format(origin_flops))
+        _logger.info("FLOPs after pruning: {}. Pruned FLOPs: {}%.".format(
+            final_flops, round(pruned_flops * 100, 2)))
         train_program_info.program = pruned_program
     elif strategy.startswith('asp'):
...
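The FLOPs report added to `build_prune_program` is a relative reduction: |before - after| / before. A tiny sketch with made-up FLOPs counts, mirroring the log format above:

```python
# Made-up FLOPs counts to illustrate the percentage reported above.
origin_flops, final_flops = 4_000_000, 3_000_000
pruned_flops = abs(origin_flops - final_flops) / origin_flops
print("FLOPs before pruning: {}".format(origin_flops))
print("FLOPs after pruning: {}. Pruned FLOPs: {}%.".format(
    final_flops, round(pruned_flops * 100, 2)))
# FLOPs after pruning: 3000000. Pruned FLOPs: 25.0%.
```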
@@ -431,7 +431,8 @@ class ProgramInfo:
                  feed_target_names,
                  fetch_targets,
                  optimizer=None,
-                 learning_rate=None):
+                 learning_rate=None,
+                 loss_dict=None):
         """
         ProgramInfo Config.
         Args:
@@ -441,6 +442,7 @@ class ProgramInfo:
             fetch_targets(list(Variable)): The fetch variable in the program.
             optimizer(Optimizer, optional): Optimizer in training. Default: None.
             learning_rate(float|paddle.optimizer.lr, optional): learning_rate in training. Default: None.
+            loss_dict(dict, optional): The components of the loss, keyed by loss name. Default: None.
         """
         self.startup_program = startup_program
         self.program = program
@@ -448,3 +450,4 @@ class ProgramInfo:
         self.fetch_targets = fetch_targets
         self.optimizer = optimizer
         self.learning_rate = learning_rate
+        self.loss_dict = loss_dict
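Finally, a construction sketch of `ProgramInfo` with the new `loss_dict` argument, mirroring the call in `build_distill_program`. The stubbed values and the import path are assumptions; in real use the programs, fetch list and loss variables come from the built distillation program.

```python
# Hypothetical construction of ProgramInfo carrying loss_dict; import path assumed.
from paddleslim.auto_compression.strategy_config import ProgramInfo

# Stub values; in build_distill_program these are real Paddle programs/variables.
startup_program, train_program = None, None
feed_target_names, train_fetch_list = ['image'], []
optimizer, learning_rate = None, 0.001
loss_dict = {'l2': None}   # loss name -> loss Variable (stubbed)

train_program_info = ProgramInfo(startup_program, train_program,
                                 feed_target_names, train_fetch_list,
                                 optimizer, learning_rate, loss_dict)

# The training loop later reads names and variables from it:
loss_names = list(train_program_info.loss_dict.keys())
loss_vars = list(train_program_info.loss_dict.values())
```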