“c12e970e274541e817aeac3d19024a09b1a46559”上不存在“dev/doc/tutorials/image_classification/index_en.html”
未验证 提交 3d0755b1 编写于 作者: C Chang Xu 提交者: GitHub

Add Early Stop in AutoCompression (#1358)

上级 d6f9aafc
...@@ -714,7 +714,10 @@ class AutoCompression: ...@@ -714,7 +714,10 @@ class AutoCompression:
best_metric = -1.0 best_metric = -1.0
total_epochs = train_config.epochs if train_config.epochs else 100 total_epochs = train_config.epochs if train_config.epochs else 100
total_train_iter = 0 total_train_iter = 0
stop_training = False
for epoch_id in range(total_epochs): for epoch_id in range(total_epochs):
if stop_training:
break
for batch_id, data in enumerate(self.train_dataloader()): for batch_id, data in enumerate(self.train_dataloader()):
np_probs_float, = self._exe.run(train_program_info.program, \ np_probs_float, = self._exe.run(train_program_info.program, \
feed=data, \ feed=data, \
...@@ -760,6 +763,10 @@ class AutoCompression: ...@@ -760,6 +763,10 @@ class AutoCompression:
abs(best_metric - abs(best_metric -
self.metric_before_compressed) self.metric_before_compressed)
) / self.metric_before_compressed <= 0.005: ) / self.metric_before_compressed <= 0.005:
_logger.info(
"The error rate between the compressed model and original model is less than 5%. The training process ends."
)
stop_training = True
break break
else: else:
_logger.info( _logger.info(
...@@ -767,14 +774,18 @@ class AutoCompression: ...@@ -767,14 +774,18 @@ class AutoCompression:
format(epoch_id, metric, best_metric)) format(epoch_id, metric, best_metric))
if train_config.target_metric is not None: if train_config.target_metric is not None:
if metric > float(train_config.target_metric): if metric > float(train_config.target_metric):
stop_training = True
_logger.info(
"The metric of compressed model has reached the target metric. The training process ends."
)
break break
else: else:
_logger.warning( _logger.warning(
"Not set eval function, so unable to test accuracy performance." "Not set eval function, so unable to test accuracy performance."
) )
if train_config.train_iter and total_train_iter >= train_config.train_iter: if (train_config.train_iter and total_train_iter >=
epoch_id = total_epochs train_config.train_iter) or stop_training:
break break
if 'unstructure' in self._strategy or train_config.sparse_model: if 'unstructure' in self._strategy or train_config.sparse_model:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册