Unverified commit a620089a, authored by: G Guanghua Yu, committed by: GitHub

support train_iter in act (#1161)

Parent: 066b7056
......@@ -12,7 +12,6 @@ repos:
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks:
- id: yapf
language_version: python3.9
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
- repo: https://github.com/pre-commit/pre-commit-hooks
sha: 5bf6c09bfa1297d3692cadd621ef95f1284e33c0
......
......@@ -34,7 +34,7 @@ Quantization:
- depthwise_conv2d
TrainConfig:
epochs: 1
train_iter: 3000
eval_iter: 1000
learning_rate: 0.00001
optimizer: SGD
......
......@@ -30,7 +30,7 @@ EvalDataset:
use_gt_bbox: True
image_thre: 0.5
worker_num: 2
worker_num: 0
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]
......
......@@ -14,7 +14,7 @@ EvalDataset:
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco/
worker_num: 4
worker_num: 0
# preprocess reader in test
EvalReader:
......
......@@ -34,7 +34,7 @@ Quantization:
- depthwise_conv2d
TrainConfig:
epochs: 1
train_iter: 3000
eval_iter: 1000
learning_rate: 0.0001
optimizer: SGD
......
......@@ -14,7 +14,7 @@ EvalDataset:
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco/
worker_num: 4
worker_num: 0
# preprocess reader in test
EvalReader:
......
......@@ -19,7 +19,7 @@ Distillation:
- teacher_conv2d_119.tmp_1
- conv2d_119.tmp_1
merge_feed: true
teacher_model_dir: ./yolov5_inference_model/
teacher_model_dir: ./yolov5s_infer/
teacher_model_filename: model.pdmodel
teacher_params_filename: model.pdiparams
......@@ -37,7 +37,7 @@ Quantization:
- depthwise_conv2d
TrainConfig:
epochs: 1
train_iter: 3000
eval_iter: 1000
learning_rate: 0.00001
optimizer: SGD
......
......@@ -265,8 +265,14 @@ class AutoCompression:
_logger.info(
"Calculating the iterations per epoch……(It will take some time)")
# NOTE:XXX: This way of calculating the iters needs to be improved.
iters_per_epoch = len(list(self.train_dataloader()))
total_iters = self.train_config.epochs * iters_per_epoch
if self.train_config.epochs:
iters_per_epoch = len(list(self.train_dataloader()))
total_iters = self.train_config.epochs * iters_per_epoch
elif self.train_config.train_iter:
total_iters = self.train_config.train_iter
else:
raise RuntimeError(
'train_config must has `epochs` or `train_iter` field.')
config_dict['gmp_config'] = {
'stable_iterations': 0,
'pruning_iterations': 0.45 * total_iters,
......@@ -498,7 +504,8 @@ class AutoCompression:
def _start_train(self, train_program_info, test_program_info, strategy):
best_metric = -1.0
for epoch_id in range(self.train_config.epochs):
total_epochs = self.train_config.epochs if self.train_config.epochs else 1
for epoch_id in range(total_epochs):
for batch_id, data in enumerate(self.train_dataloader()):
np_probs_float, = self._exe.run(train_program_info.program, \
feed=data, \
......@@ -551,6 +558,8 @@ class AutoCompression:
_logger.warning(
"Not set eval function, so unable to test accuracy performance."
)
if self.train_config.train_iter and batch_id >= self.train_config.train_iter:
break
if 'unstructure' in self._strategy or self.train_config.sparse_model:
self._pruner.update_params()
......
......@@ -103,11 +103,24 @@ UnstructurePrune = namedtuple("UnstructurePrune", [
UnstructurePrune.__new__.__defaults__ = (None, ) * len(UnstructurePrune._fields)
### Train
TrainConfig = namedtuple("Train", [
"epochs", "learning_rate", "optimizer", "optim_args", "eval_iter",
"logging_iter", "origin_metric", "target_metric", "use_fleet", "amp_config",
"recompute_config", "sharding_config", "sparse_model"
])
# Training-stage configuration for auto-compression.
# Every field defaults to None (see the __defaults__ assignment below), so
# callers only set the fields they need.
TrainConfig = namedtuple(
    "Train",
    [
        "epochs",  # Total number of training epochs.
        "train_iter",  # Total number of training iterations; set only one of `epochs` or `train_iter`.
        "learning_rate",  # Learning rate for the optimizer.
        "optimizer",  # Optimizer name, e.g. "SGD" -- presumably resolved elsewhere; confirm against caller.
        "optim_args",  # Extra keyword arguments for the optimizer (assumed; TODO confirm).
        "eval_iter",  # NOTE(review): looks like the interval (in iterations) between evaluations -- confirm.
        "logging_iter",  # NOTE(review): looks like the interval (in iterations) between log lines -- confirm.
        "origin_metric",  # Baseline metric of the original model (assumed; TODO confirm).
        "target_metric",  # Target metric for the compressed model (assumed; TODO confirm).
        "use_fleet",  # Whether to use distributed (fleet) training (assumed; TODO confirm).
        "amp_config",  # Automatic mixed precision configuration (assumed; TODO confirm).
        "recompute_config",  # Recompute (gradient checkpointing) configuration (assumed; TODO confirm).
        "sharding_config",  # Parameter sharding configuration (assumed; TODO confirm).
        "sparse_model",  # Checked in _start_train to decide whether to update pruner params.
    ])
# Make every field optional by defaulting all of them to None.
TrainConfig.__new__.__defaults__ = (None, ) * len(TrainConfig._fields)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.