未验证 提交 a620089a 编写于 作者: G Guanghua Yu 提交者: GitHub

support train_iter in act (#1161)

上级 066b7056
...@@ -12,7 +12,6 @@ repos: ...@@ -12,7 +12,6 @@ repos:
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37 sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks: hooks:
- id: yapf - id: yapf
language_version: python3.9
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
sha: 5bf6c09bfa1297d3692cadd621ef95f1284e33c0 sha: 5bf6c09bfa1297d3692cadd621ef95f1284e33c0
......
...@@ -34,7 +34,7 @@ Quantization: ...@@ -34,7 +34,7 @@ Quantization:
- depthwise_conv2d - depthwise_conv2d
TrainConfig: TrainConfig:
epochs: 1 train_iter: 3000
eval_iter: 1000 eval_iter: 1000
learning_rate: 0.00001 learning_rate: 0.00001
optimizer: SGD optimizer: SGD
......
...@@ -30,7 +30,7 @@ EvalDataset: ...@@ -30,7 +30,7 @@ EvalDataset:
use_gt_bbox: True use_gt_bbox: True
image_thre: 0.5 image_thre: 0.5
worker_num: 2 worker_num: 0
global_mean: &global_mean [0.485, 0.456, 0.406] global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225] global_std: &global_std [0.229, 0.224, 0.225]
......
...@@ -14,7 +14,7 @@ EvalDataset: ...@@ -14,7 +14,7 @@ EvalDataset:
anno_path: annotations/instances_val2017.json anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco/ dataset_dir: dataset/coco/
worker_num: 4 worker_num: 0
# preprocess reader in test # preprocess reader in test
EvalReader: EvalReader:
......
...@@ -34,7 +34,7 @@ Quantization: ...@@ -34,7 +34,7 @@ Quantization:
- depthwise_conv2d - depthwise_conv2d
TrainConfig: TrainConfig:
epochs: 1 train_iter: 3000
eval_iter: 1000 eval_iter: 1000
learning_rate: 0.0001 learning_rate: 0.0001
optimizer: SGD optimizer: SGD
......
...@@ -14,7 +14,7 @@ EvalDataset: ...@@ -14,7 +14,7 @@ EvalDataset:
anno_path: annotations/instances_val2017.json anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco/ dataset_dir: dataset/coco/
worker_num: 4 worker_num: 0
# preprocess reader in test # preprocess reader in test
EvalReader: EvalReader:
......
...@@ -19,7 +19,7 @@ Distillation: ...@@ -19,7 +19,7 @@ Distillation:
- teacher_conv2d_119.tmp_1 - teacher_conv2d_119.tmp_1
- conv2d_119.tmp_1 - conv2d_119.tmp_1
merge_feed: true merge_feed: true
teacher_model_dir: ./yolov5_inference_model/ teacher_model_dir: ./yolov5s_infer/
teacher_model_filename: model.pdmodel teacher_model_filename: model.pdmodel
teacher_params_filename: model.pdiparams teacher_params_filename: model.pdiparams
...@@ -37,7 +37,7 @@ Quantization: ...@@ -37,7 +37,7 @@ Quantization:
- depthwise_conv2d - depthwise_conv2d
TrainConfig: TrainConfig:
epochs: 1 train_iter: 3000
eval_iter: 1000 eval_iter: 1000
learning_rate: 0.00001 learning_rate: 0.00001
optimizer: SGD optimizer: SGD
......
...@@ -265,8 +265,14 @@ class AutoCompression: ...@@ -265,8 +265,14 @@ class AutoCompression:
_logger.info( _logger.info(
"Calculating the iterations per epoch……(It will take some time)") "Calculating the iterations per epoch……(It will take some time)")
# NOTE:XXX: This way of calculating the iters needs to be improved. # NOTE:XXX: This way of calculating the iters needs to be improved.
iters_per_epoch = len(list(self.train_dataloader())) if self.train_config.epochs:
total_iters = self.train_config.epochs * iters_per_epoch iters_per_epoch = len(list(self.train_dataloader()))
total_iters = self.train_config.epochs * iters_per_epoch
elif self.train_config.train_iter:
total_iters = self.train_config.train_iter
else:
raise RuntimeError(
'train_config must has `epochs` or `train_iter` field.')
config_dict['gmp_config'] = { config_dict['gmp_config'] = {
'stable_iterations': 0, 'stable_iterations': 0,
'pruning_iterations': 0.45 * total_iters, 'pruning_iterations': 0.45 * total_iters,
...@@ -498,7 +504,8 @@ class AutoCompression: ...@@ -498,7 +504,8 @@ class AutoCompression:
def _start_train(self, train_program_info, test_program_info, strategy): def _start_train(self, train_program_info, test_program_info, strategy):
best_metric = -1.0 best_metric = -1.0
for epoch_id in range(self.train_config.epochs): total_epochs = self.train_config.epochs if self.train_config.epochs else 1
for epoch_id in range(total_epochs):
for batch_id, data in enumerate(self.train_dataloader()): for batch_id, data in enumerate(self.train_dataloader()):
np_probs_float, = self._exe.run(train_program_info.program, \ np_probs_float, = self._exe.run(train_program_info.program, \
feed=data, \ feed=data, \
...@@ -551,6 +558,8 @@ class AutoCompression: ...@@ -551,6 +558,8 @@ class AutoCompression:
_logger.warning( _logger.warning(
"Not set eval function, so unable to test accuracy performance." "Not set eval function, so unable to test accuracy performance."
) )
if self.train_config.train_iter and batch_id >= self.train_config.train_iter:
break
if 'unstructure' in self._strategy or self.train_config.sparse_model: if 'unstructure' in self._strategy or self.train_config.sparse_model:
self._pruner.update_params() self._pruner.update_params()
......
...@@ -103,11 +103,24 @@ UnstructurePrune = namedtuple("UnstructurePrune", [ ...@@ -103,11 +103,24 @@ UnstructurePrune = namedtuple("UnstructurePrune", [
UnstructurePrune.__new__.__defaults__ = (None, ) * len(UnstructurePrune._fields) UnstructurePrune.__new__.__defaults__ = (None, ) * len(UnstructurePrune._fields)
### Train ### Train
TrainConfig = namedtuple("Train", [ TrainConfig = namedtuple(
"epochs", "learning_rate", "optimizer", "optim_args", "eval_iter", "Train",
"logging_iter", "origin_metric", "target_metric", "use_fleet", "amp_config", [
"recompute_config", "sharding_config", "sparse_model" "epochs", # Training total epoch
]) "train_iter", # Training total iteration, `epochs` or `train_iter` only need to set one.
"learning_rate",
"optimizer",
"optim_args",
"eval_iter",
"logging_iter",
"origin_metric",
"target_metric",
"use_fleet",
"amp_config",
"recompute_config",
"sharding_config",
"sparse_model",
])
TrainConfig.__new__.__defaults__ = (None, ) * len(TrainConfig._fields) TrainConfig.__new__.__defaults__ = (None, ) * len(TrainConfig._fields)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册