Unverified commit 060c13ec authored by oqqZun1, committed by GitHub

add classifier auto_augment demo

Parent 400250ed
# PaddleHub Auto Data Augmentation
This demo shows how to use PaddleHub to search for the data augmentation policy best suited to your dataset and apply it to model training.
## Dependencies
Please install the auto-augment package from pip first:
```
pip install auto-augment
```
## About auto-augment
The auto-augment package currently supports Paddle image classification and object detection tasks.
Usage is split into two stages: search and train.
**The search stage searches over combinations of augmentation operators on a preset model and outputs the best data augmentation schedule.**
**The train stage applies the best augmentation schedule when training a specific model.**
For details on auto-augment usage and benchmarks, see the README under auto_augment/doc.
## Supported tasks
auto-augment currently supports PaddleHub image classification tasks;
support for other tasks will be added later.
## Image classification
### Configuration
Configuration can be written in YAML or JSON; this project only provides YAML templates, located under configs/.
User-configurable parameters are grouped into task_config (task), data_config (data), resource_config (resources), algo_config (algorithm) and search_space (search space).
#### task_config (task configuration)
Task details, including the task type and model settings.
Fields:
- run_mode: ["ray", "automl_service"], backend service; currently only the single-machine ray framework is supported
- workspace: the user workspace directory
- task_type: ["classifier"], task type; PaddleHub currently supports single-label image classification. For single-label object detection augmentation, refer to auto_augment/doc
- classifier: configuration details for the specific task type, listed below
##### classifier configuration details
- model_name: PaddleHub model name
- epochs: int, number of search epochs, **required**; this parameter must be set explicitly
- input_size: model input size
- scale_size: image size used during data preprocessing
- no_cache_img: do not cache images, default False
- use_class_map: map labels through label_list
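For reference, the task_config used by the flower demo config later in this commit looks like this:
```
task_config:
  run_mode: "ray"
  workspace: "./work_dirs/pbt_hub_classifer/test_autoaug"
  task_type: "classifier"
  classifier:
    model_name: "resnet50_vd_imagenet_ssld"
    epochs: 100
    input_size: 224
    scale_size: 256
    no_cache_img: false
    use_class_map: false
```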
#### data_config (data configuration)
The data configuration supports several input formats, including the txt annotation format for image classification and the VOC/COCO annotation formats for object detection. The fields are listed below; a filled-in example follows the list.
- train_img_prefix: str, path prefix of the training images
- train_ann_file: str, annotation file of the training set
- val_img_prefix: str, path prefix of the validation images
- val_ann_file: str, annotation file of the validation set
- label_list: str, label file
- delimiter: delimiter used in the annotation files, e.g. ","
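Each line of an annotation file contains an image path (relative to the corresponding img_prefix) and its label, separated by the configured delimiter. The flower demo config in this commit fills the fields as follows:
```
data_config:
  train_img_prefix: "./dataset/flower_photos"
  train_ann_file: "./dataset/flower_photos/train_list.txt"
  val_img_prefix: "./dataset/flower_photos"
  val_ann_file: "./dataset/flower_photos/validate_list.txt"
  label_list: "./dataset/flower_photos/label_list.txt"
  delimiter: " "
```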
#### resource_config (resource configuration)
- gpu: float, GPU resources allocated to each search process; fractional allocation is supported when run_mode == "ray"
- cpu: float, CPU resources allocated to each search process; fractional allocation is supported when run_mode == "ray"
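In the demo config each search trial is granted a fraction of a GPU, so more than one trial can share a single card:
```
resource_config:
  gpu: 0.4
  cpu: 1
```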
#### algo_config (algorithm configuration)
Only PBA is currently supported; more algorithms will be added later.
##### PBA configuration
- algo_name: str, ["PBA"], search algorithm
- algo_param:
  - perturbation_interval: perturbation interval of the search
  - num_samples: number of search processes
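The demo config runs 8 parallel search trials and perturbs the policies every 3 rounds:
```
algo_config:
  algo_name: "PBA"
  algo_param:
    perturbation_interval: 3
    num_samples: 8
```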
#### search_space (search space configuration)
Defines the search space. Required for the policy search stage; ignored when training with a fixed policy.
- operators_repeat: int, default 1; the number of times each searched operator is repeated.
- operator_space: the operator space to search, which can be defined in two modes (see the example after this list):
1. Custom operator mode:
   - htype: str, ["choice"], hyperparameter type; currently only choice enumeration is supported
   - value: list, e.g. [0, 0.5, 1], the enumerated values
![image-20200707162627074](./doc/operators.png)
2. Abbreviated operator mode:
   Only the operator names to search need to be specified; the prob and magtitude search spaces default to values between 0 and 1.
![image-20200707162709253](./doc/short_operators.png)
Modes 1 and 2 can be mixed.
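The snippet below shows one operator written in custom mode, taken from the demo config in this commit; the abbreviated entry that follows is only a sketch of mode 2 (see auto_augment/doc for the exact abbreviated syntax):
```
operator_space:
  # custom mode: explicit choice spaces for prob and magtitude
  -
    name: Sharpness
    prob:
      htype: choice
      value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
    magtitude:
      htype: choice
      value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
  # abbreviated mode (sketch): only the operator name is given, prob/magtitude default to 0-1
  -
    name: Rotate
```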
##### Image classification operators
["Sharpness", "Rotate", "Invert", "Brightness", "Cutout", "Equalize","TranslateY", "AutoContrast", "Color","TranslateX", "Solarize", "ShearX","Contrast", "Posterize", "ShearY", "FlipLR"]
### Search stage
Searches for a data augmentation policy.
### Train stage
Applies the searched data augmentation policy during training.
### Demo
#### Preparing the flower dataset
```
cd PaddleHub/demo/autaug/
mkdir -p ./dataset
cd dataset
wget https://bj.bcebos.com/paddlehub-dataset/flower_photos.tar.gz
tar -xvf flower_photos.tar.gz
```
#### Search
```
cd PaddleHub/demo/autaug/
bash search.sh
# The result is dumped into the workspace as a JSON file, which can then be used for training.
```
#### Training
```
cd PaddleHub/demo/autaug/
bash train.sh
```
# -*- coding: utf-8 -*-
#*******************************************************************************
#
# Copyright (c) 2020 Baidu.com, Inc. All Rights Reserved
#
#*******************************************************************************
"""
Authors: lvhaijun01@baidu.com
Date: 2020-11-24 20:43
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import six
import os
from typing import Dict, List, Optional, Union, Tuple
from auto_augment.autoaug.utils import log
import logging
logger = log.get_logger(level=logging.INFO)
import auto_augment
auto_augment_path = auto_augment.__file__
class HubFitterClassifer(object):
"""Trains an instance of the Model class."""
def __init__(self, hparams: dict) -> None:
"""
        Define the data and model for the classification task.
Args:
hparams:
"""
def set_paddle_flags(**kwargs):
for key, value in kwargs.items():
if os.environ.get(key, None) is None:
os.environ[key] = str(value)
# NOTE(paddle-dev): All of these flags should be set before
# `import paddle`. Otherwise, it would not take any effect.
set_paddle_flags(
# enable GC to save memory
FLAGS_fraction_of_gpu_memory_to_use=hparams.resource_config.gpu,
)
import paddle
import paddlehub as hub
from paddlehub_utils.trainer import CustomTrainer
from paddlehub_utils.reader import _init_loader
        # TODO: fleet distributed training is not supported yet
# from paddle.fluid.incubate.fleet.base import role_maker
# from paddle.fluid.incubate.fleet.collective import fleet
# role = role_maker.PaddleCloudRoleMaker(is_collective=True)
# fleet.init(role)
logger.info("classficiation data augment search begin")
self.hparams = hparams
# param compatible
self._fit_param(show=True)
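        # switch to dynamic graph mode on the GPU indexed by this process's distributed rank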
paddle.disable_static(paddle.CUDAPlace(paddle.distributed.get_rank()))
train_dataset, eval_dataset = _init_loader(self.hparams)
model = hub.Module(name=hparams["task_config"]["classifier"]["model_name"], label_list=self.class_to_id_dict.keys(), load_checkpoint=None)
optimizer = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
trainer = CustomTrainer(
model=model,
optimizer=optimizer,
checkpoint_dir='img_classification_ckpt')
self.model = model
self.optimizer = optimizer
trainer.init_train_and_eval(
train_dataset,
epochs=100,
batch_size=32,
eval_dataset=eval_dataset,
save_interval=1)
self.trainer = trainer
def _fit_param(self, show: bool = False) -> None:
"""
param fit
Args:
hparams:
Returns:
"""
hparams = self.hparams
self._get_label_info(hparams)
def _get_label_info(self, hparams: dict) -> None:
"""
Args:
hparams:
Returns:
"""
from paddlehub_utils.reader import _read_classes
data_config = hparams.data_config
label_list = data_config.label_list
if os.path.isfile(label_list):
class_to_id_dict = _read_classes(label_list)
else:
assert 0, "label_list:{} not exist".format(label_list)
self.num_classes = len(class_to_id_dict)
self.class_to_id_dict = class_to_id_dict
def reset_config(self, new_hparams: dict) -> None:
"""
reset config, used by search stage
Args:
new_hparams:
Returns:
"""
self.hparams = new_hparams
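        # propagate the newly perturbed search_space to the training dataset's augmentation transform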
self.trainer.train_loader.dataset.reset_policy(
new_hparams.search_space)
return None
def save_model(self, checkpoint_dir: str, step: Optional[str] = None) -> str:
"""Dumps model into the backup_dir.
Args:
step: If provided, creates a checkpoint with the given step
number, instead of overwriting the existing checkpoints.
"""
checkpoint_path = os.path.join(checkpoint_dir,
'epoch') + '-' + str(step)
logger.info('Saving model checkpoint to {}'.format(checkpoint_path))
self.trainer.save_model(os.path.join(checkpoint_path, "checkpoint"))
return checkpoint_path
def extract_model_spec(self, checkpoint_path: str) -> None:
"""Loads a checkpoint with the architecture structure stored in the name."""
ckpt_path = os.path.join(checkpoint_path, "checkpoint")
self.trainer.load_model(ckpt_path)
logger.info(
'Loaded child model checkpoint from {}'.format(checkpoint_path))
def eval_child_model(self, mode: str, pass_id: int = 0) -> dict:
"""Evaluate the child model.
Args:
model: image model that will be evaluated.
data_loader: dataset object to extract eval data from.
            mode: whether the model is evaluated on train, val or test.
Returns:
Accuracy of the model on the specified dataset.
"""
eval_loader = self.trainer.eval_loader
res = self.trainer.evaluate_process(eval_loader)
top1_acc = res["metrics"]["acc"]
if mode == "val":
return {"val_acc": top1_acc}
elif mode == "test":
return {"test_acc": top1_acc}
else:
raise NotImplementedError
def train_one_epoch(self, pass_id: int) -> dict:
"""
Args:
model:
train_loader:
optimizer:
Returns:
"""
from paddlehub.utils.utils import Timer
batch_sampler = self.trainer.batch_sampler
train_loader = self.trainer.train_loader
steps_per_epoch = len(batch_sampler)
task_config = self.hparams.task_config
task_type = task_config.task_type
epochs = task_config.classifier.epochs
timer = Timer(steps_per_epoch * epochs)
timer.start()
self.trainer.train_one_epoch(
loader=train_loader,
timer=timer,
current_epoch=pass_id,
epochs=epochs,
log_interval=10,
steps_per_epoch=steps_per_epoch)
return {"train_acc": 0}
def _run_training_loop(self, curr_epoch: int) -> dict:
"""Trains the model `m` for one epoch."""
start_time = time.time()
train_acc = self.train_one_epoch(curr_epoch)
logger.info(
'Epoch:{} time(min): {}'.format(
curr_epoch,
(time.time() - start_time) / 60.0))
return train_acc
def _compute_final_accuracies(self, iteration: int) -> dict:
"""Run once training is finished to compute final test accuracy."""
task_config = self.hparams.task_config
task_type = task_config.task_type
if (iteration >= task_config[task_type].epochs - 1):
test_acc = self.eval_child_model('test', iteration)
pass
else:
test_acc = {"test_acc": 0}
logger.info('Test acc: {}'.format(test_acc))
return test_acc
def run_model(self, epoch: int) -> dict:
"""Trains and evalutes the image model."""
self._fit_param()
train_acc = self._run_training_loop(epoch)
valid_acc = self.eval_child_model(mode="val", pass_id=epoch)
logger.info('valid acc: {}'.format(
valid_acc))
all_metric = {}
all_metric.update(train_acc)
all_metric.update(valid_acc)
return all_metric
# -*- coding: utf-8 -*-
#*******************************************************************************
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
#*******************************************************************************
"""
Authors: lvhaijun01@baidu.com
Date: 2019-09-17 14:15
"""
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
# *******************************************************************************
#
# Copyright (c) 2020 Baidu.com, Inc. All Rights Reserved
#
# *******************************************************************************
"""
Authors: lvhaijun01@baidu.com
Date: 2019-06-30 00:10
"""
import re
import numpy as np
from typing import Dict, List, Optional, Union, Tuple
import six
import cv2
import os
import paddle
import paddlehub.vision.transforms as transforms
from PIL import ImageFile
from auto_augment.autoaug.transform.autoaug_transform import AutoAugTransform
ImageFile.LOAD_TRUNCATED_IMAGES = True
__imagenet_stats = {'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225]}
class PbaAugment(object):
"""
    PbaAugment transform for image classification.
"""
def __init__(
self,
input_size: int = 224,
scale_size: int = 256,
normalize: Optional[list] = None,
pre_transform: bool = True,
stage: str = "search",
**kwargs) -> None:
"""
Args:
input_size:
scale_size:
normalize:
pre_transform:
**kwargs:
"""
if normalize is None:
normalize = {
'mean': [
0.485, 0.456, 0.406], 'std': [
0.229, 0.224, 0.225]}
policy = kwargs["policy"]
assert stage in ["search", "train"]
train_epochs = kwargs["hp_policy_epochs"]
self.auto_aug_transform = AutoAugTransform.create(
policy, stage=stage, train_epochs=train_epochs)
#self.auto_aug_transform = PbtAutoAugmentClassiferTransform(conf)
        if pre_transform:
            self.pre_transform = transforms.Resize(input_size)
        else:
            self.pre_transform = None
self.post_transform = transforms.Compose(
transforms=[
transforms.Permute(),
transforms.Normalize(**normalize, channel_first=True)
],
channel_first = False
)
self.cur_epoch = 0
def set_epoch(self, indx: int) -> None:
"""
Args:
indx:
Returns:
"""
self.auto_aug_transform.set_epoch(indx)
def reset_policy(self, new_hparams: dict) -> None:
"""
Args:
new_hparams:
Returns:
"""
self.auto_aug_transform.reset_policy(new_hparams)
def __call__(self, img: np.ndarray):
"""
Args:
            img: input image as a numpy array (H, W, C), RGB
Returns:
"""
        # resize before applying the augmentation policy
if self.pre_transform:
img = self.pre_transform(img)
img = self.auto_aug_transform.apply(img)
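        # make sure the augmented image is uint8 before the permute/normalize post-transform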
img = img.astype(np.uint8)
img = self.post_transform(img)
return img
class PicRecord(object):
"""
PicRecord
"""
def __init__(self, row: list) -> None:
"""
Args:
row:
"""
self._data = row
@property
def sub_path(self) -> str:
"""
Returns:
"""
return self._data[0]
@property
def label(self) -> str:
"""
Returns:
"""
return self._data[1]
class PicReader(paddle.io.Dataset):
"""
PicReader
"""
def __init__(
self,
root_path: str,
list_file: str,
meta: bool = False,
transform: Optional[callable] = None,
class_to_id_dict: Optional[dict] = None,
cache_img: bool = False,
**kwargs) -> None:
"""
Args:
root_path:
list_file:
meta:
transform:
class_to_id_dict:
cache_img:
**kwargs:
"""
self.root_path = root_path
self.list_file = list_file
self.transform = transform
self.meta = meta
self.class_to_id_dict = class_to_id_dict
self.train_type = kwargs["conf"].get("train_type", "single_label")
self.class_num = kwargs["conf"].get("class_num", 0)
self._parse_list(**kwargs)
self.cache_img = cache_img
self.cache_img_buff = dict()
if self.cache_img:
self._get_all_img(**kwargs)
def _get_all_img(self, **kwargs) -> None:
"""
        Cache all images with a pre-resize to reduce memory usage.
Returns:
"""
scale_size = kwargs.get("scale_size", 256)
for idx in range(len(self)):
record = self.pic_list[idx]
relative_path = record.sub_path
if self.root_path is not None:
image_path = os.path.join(self.root_path, relative_path)
else:
image_path = relative_path
try:
img = cv2.imread(image_path, cv2.IMREAD_COLOR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (scale_size, scale_size))
self.cache_img_buff[image_path] = img
except BaseException:
print("img_path:{} can not by cv2".format(
image_path).format(image_path))
pass
def _load_image(self, directory: str) -> np.ndarray:
"""
Args:
directory:
Returns:
"""
if not self.cache_img:
img = cv2.imread(directory, cv2.IMREAD_COLOR).astype('float32')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# img = Image.open(directory).convert('RGB')
else:
if directory in self.cache_img_buff:
img = self.cache_img_buff[directory]
else:
img = cv2.imread(directory, cv2.IMREAD_COLOR).astype('float32')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# img = Image.open(directory).convert('RGB')
return img
def _parse_list(self, **kwargs) -> None:
"""
Args:
**kwargs:
Returns:
"""
delimiter = kwargs.get("delimiter", " ")
self.pic_list = []
with open(self.list_file) as f:
lines = f.read().splitlines()
print(
"PicReader:: found {} picture in `{}'".format(
len(lines), self.list_file))
for i, line in enumerate(lines):
record = re.split(delimiter, line)
# record = line.split()
assert len(record) == 2, "length of record is not 2!"
if not os.path.splitext(record[0])[1]:
                    # handle online classification data whose file names have no extension
record[0] = record[0] + ".jpg"
                # online single-label data may carry multi-label annotations; keep only the first label (to be removed later)
record[1] = re.split(",", record[1])[0]
self.pic_list.append(PicRecord(record))
def __getitem__(self, index: int):
"""
Args:
index:
Returns:
"""
record = self.pic_list[index]
return self.get(record)
def get(self, record: PicRecord) -> tuple:
"""
Args:
record:
Returns:
"""
relative_path = record.sub_path
if self.root_path is not None:
image_path = os.path.join(self.root_path, relative_path)
else:
image_path = relative_path
img = self._load_image(image_path)
# print("org img sum:{}".format(np.sum(np.asarray(img))))
process_data = self.transform(img)
if self.train_type == "single_label":
if self.class_to_id_dict:
label = self.class_to_id_dict[record.label]
else:
label = int(record.label)
elif self.train_type == "multi_labels":
label_tensor = np.zeros((1, self.class_num))
for label in record.label.split(","):
label_tensor[0, int(label)] = 1
label_tensor = np.squeeze(label_tensor)
label = label_tensor
if self.meta:
return process_data, label, relative_path
else:
return process_data, label
def __len__(self) -> int:
"""
Returns:
"""
return len(self.pic_list)
def set_meta(self, meta: bool) -> None:
"""
Args:
meta:
Returns:
"""
self.meta = meta
def set_epoch(self, epoch: int) -> None:
"""
Args:
epoch:
Returns:
"""
if self.transform is not None:
self.transform.set_epoch(epoch)
# only use in search
def reset_policy(self, new_hparams: dict) -> None:
"""
Args:
new_hparams:
Returns:
"""
if self.transform is not None:
self.transform.reset_policy(new_hparams)
def _parse(value: str, function: callable, fmt: str):
"""
Parse a string into a value, and format a nice ValueError if it fails.
Returns `function(value)`.
    Any `ValueError` raised is caught and a new `ValueError` is raised
with message `fmt.format(e)`, where `e` is the caught `ValueError`.
"""
try:
return function(value)
except ValueError as e:
six.raise_from(ValueError(fmt.format(e)), None)
def _read_classes(csv_file: str) -> dict:
""" Parse the classes file.
"""
result = {}
with open(csv_file) as csv_reader:
for line, row in enumerate(csv_reader):
try:
class_name = row.strip()
# print(class_id, class_name)
except ValueError:
six.raise_from(
ValueError(
'line {}: format should be \'class_name\''.format(line)),
None)
class_id = _parse(
line,
int,
'line {}: malformed class ID: {{}}'.format(line))
if class_name in result:
raise ValueError(
'line {}: duplicate class name: \'{}\''.format(
line, class_name))
result[class_name] = class_id
return result
def _init_loader(hparams: dict, TrainTransform=None) -> tuple:
"""
Args:
hparams:
Returns:
"""
train_data_root = hparams.data_config.train_img_prefix
val_data_root = hparams.data_config.val_img_prefix
train_list = hparams.data_config.train_ann_file
val_list = hparams.data_config.val_ann_file
input_size = hparams.task_config.classifier.input_size
scale_size = hparams.task_config.classifier.scale_size
search_space = hparams.search_space
search_space["task_type"] = hparams.task_config.task_type
epochs = hparams.task_config.classifier.epochs
no_cache_img = hparams.task_config.classifier.get("no_cache_img", False)
normalize = {
'mean': [
0.485, 0.456, 0.406], 'std': [
0.229, 0.224, 0.225]}
if TrainTransform is None:
TrainTransform = PbaAugment(
input_size=input_size,
scale_size=scale_size,
normalize=normalize,
policy=search_space,
hp_policy_epochs=epochs,
)
delimiter = hparams.data_config.delimiter
kwargs = dict(
conf=hparams,
delimiter=delimiter
)
if hparams.task_config.classifier.use_class_map:
        class_to_id_dict = _read_classes(hparams.data_config.label_list)
else:
class_to_id_dict = None
train_data = PicReader(
root_path=train_data_root,
list_file=train_list,
transform=TrainTransform,
class_to_id_dict=class_to_id_dict,
cache_img=not no_cache_img,
**kwargs)
val_data = PicReader(
root_path=val_data_root,
list_file=val_list,
transform=transforms.Compose(
transforms=[
transforms.Resize(
(224,
224)),
transforms.Permute(),
transforms.Normalize(
**normalize, channel_first=True)],
channel_first = False),
class_to_id_dict=class_to_id_dict,
cache_img=not no_cache_img,
**kwargs)
return train_data, val_data
# -*- coding: utf-8 -*-
#*******************************************************************************
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
#*******************************************************************************
"""
Authors: lvhaijun01@baidu.com
Date: 2020-11-24 20:46
"""
from paddlehub.finetune.trainer import Trainer
import os
from collections import defaultdict
import paddle
from paddle.distributed import ParallelEnv
from paddlehub.utils.log import logger
from paddlehub.utils.utils import Timer
class CustomTrainer(Trainer):
def __init__(self, **kwargs) -> None:
super(CustomTrainer, self).__init__(**kwargs)
def init_train_and_eval(self,
train_dataset: paddle.io.Dataset,
epochs: int = 1,
batch_size: int = 1,
num_workers: int = 0,
eval_dataset: paddle.io.Dataset = None,
log_interval: int = 10,
save_interval: int = 10) -> None:
self.batch_sampler, self.train_loader = self.init_train(train_dataset, batch_size, num_workers)
self.eval_loader = self.init_evaluate(eval_dataset, batch_size, num_workers)
def init_train(self,
train_dataset: paddle.io.Dataset,
batch_size: int = 1,
num_workers: int = 0) -> tuple:
use_gpu = True
place = paddle.CUDAPlace(ParallelEnv().dev_id) if use_gpu else paddle.CPUPlace()
paddle.disable_static(place)
batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
loader = paddle.io.DataLoader(
train_dataset, batch_sampler=batch_sampler, places=place, num_workers=num_workers, return_list=True)
return batch_sampler, loader
def train_one_epoch(self, loader: paddle.io.DataLoader, timer: Timer, current_epoch: int, epochs: int, log_interval: int, steps_per_epoch: int) -> None:
avg_loss = 0
avg_metrics = defaultdict(int)
self.model.train()
for batch_idx, batch in enumerate(loader):
loss, metrics = self.training_step(batch, batch_idx)
self.optimizer_step(current_epoch, batch_idx, self.optimizer, loss)
self.optimizer_zero_grad(current_epoch, batch_idx, self.optimizer)
# calculate metrics and loss
avg_loss += loss.numpy()[0]
for metric, value in metrics.items():
avg_metrics[metric] += value.numpy()[0]
timer.count()
if (batch_idx + 1) % log_interval == 0 and self.local_rank == 0:
lr = self.optimizer.get_lr()
avg_loss /= log_interval
if self.use_vdl:
self.log_writer.add_scalar(tag='TRAIN/loss', step=timer.current_step, value=avg_loss)
print_msg = 'Epoch={}/{}, Step={}/{}'.format(current_epoch, epochs, batch_idx + 1,
steps_per_epoch)
print_msg += ' loss={:.4f}'.format(avg_loss)
for metric, value in avg_metrics.items():
value /= log_interval
if self.use_vdl:
self.log_writer.add_scalar(
tag='TRAIN/{}'.format(metric), step=timer.current_step, value=value)
print_msg += ' {}={:.4f}'.format(metric, value)
print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(lr, timer.timing, timer.eta)
logger.train(print_msg)
avg_loss = 0
avg_metrics = defaultdict(int)
def train(self,
train_dataset: paddle.io.Dataset,
epochs: int = 1,
batch_size: int = 1,
num_workers: int = 0,
eval_dataset: paddle.io.Dataset = None,
log_interval: int = 10,
save_interval: int = 10):
'''
Train a model with specific config.
Args:
train_dataset(paddle.io.Dataset) : Dataset to train the model
epochs(int) : Number of training loops, default is 1.
batch_size(int) : Batch size of per step, default is 1.
num_workers(int) : Number of subprocess to load data, default is 0.
            eval_dataset(paddle.io.Dataset) : The validation dataset, default is None. If set, the Trainer will
execute evaluate function every `save_interval` epochs.
            log_interval(int) : Log the train information every `log_interval` steps.
save_interval(int) : Save the checkpoint every `save_interval` epochs.
'''
batch_sampler, loader = self.init_train(train_dataset, batch_size, num_workers)
steps_per_epoch = len(batch_sampler)
timer = Timer(steps_per_epoch * epochs)
timer.start()
for i in range(epochs):
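            # advance the augmentation policy schedule to the epoch being trained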
            loader.dataset.set_epoch(i)
self.current_epoch += 1
self.train_one_epoch(loader, timer, self.current_epoch, epochs, log_interval, steps_per_epoch)
# todo, why paddlehub put save, eval in batch?
if self.current_epoch % save_interval == 0 and self.local_rank == 0:
if eval_dataset:
result = self.evaluate(eval_dataset, batch_size, num_workers)
eval_loss = result.get('loss', None)
eval_metrics = result.get('metrics', {})
if self.use_vdl:
if eval_loss:
self.log_writer.add_scalar(tag='EVAL/loss', step=timer.current_step, value=eval_loss)
for metric, value in eval_metrics.items():
self.log_writer.add_scalar(
tag='EVAL/{}'.format(metric), step=timer.current_step, value=value)
if not self.best_metrics or self.compare_metrics(self.best_metrics, eval_metrics):
self.best_metrics = eval_metrics
best_model_path = os.path.join(self.checkpoint_dir, 'best_model')
self.save_model(best_model_path)
self._save_metrics()
metric_msg = [
'{}={:.4f}'.format(metric, value) for metric, value in self.best_metrics.items()
]
metric_msg = ' '.join(metric_msg)
logger.eval('Saving best model to {} [best {}]'.format(best_model_path, metric_msg))
self._save_checkpoint()
def init_evaluate(self, eval_dataset: paddle.io.Dataset, batch_size: int, num_workers: int) -> paddle.io.DataLoader:
use_gpu = True
place = paddle.CUDAPlace(ParallelEnv().dev_id) if use_gpu else paddle.CPUPlace()
paddle.disable_static(place)
batch_sampler = paddle.io.DistributedBatchSampler(
eval_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
loader = paddle.io.DataLoader(
eval_dataset, batch_sampler=batch_sampler, places=place, num_workers=num_workers, return_list=True)
return loader
def evaluate_process(self, loader: paddle.io.DataLoader) -> dict:
self.model.eval()
avg_loss = num_samples = 0
sum_metrics = defaultdict(int)
avg_metrics = defaultdict(int)
for batch_idx, batch in enumerate(loader):
result = self.validation_step(batch, batch_idx)
loss = result.get('loss', None)
metrics = result.get('metrics', {})
bs = batch[0].shape[0]
num_samples += bs
if loss:
avg_loss += loss.numpy()[0] * bs
for metric, value in metrics.items():
sum_metrics[metric] += value.numpy()[0] * bs
# print avg metrics and loss
print_msg = '[Evaluation result]'
if loss:
avg_loss /= num_samples
print_msg += ' avg_loss={:.4f}'.format(avg_loss)
for metric, value in sum_metrics.items():
avg_metrics[metric] = value / num_samples
print_msg += ' avg_{}={:.4f}'.format(metric, avg_metrics[metric])
logger.eval(print_msg)
if loss:
return {'loss': avg_loss, 'metrics': avg_metrics}
return {'metrics': avg_metrics}
def evaluate(self, eval_dataset: paddle.io.Dataset, batch_size: int = 1, num_workers: int = 0) -> dict:
'''
Run evaluation and returns metrics.
Args:
eval_dataset(paddle.io.Dataset) : The validation dataset
batch_size(int) : Batch size of per step, default is 1.
num_workers(int) : Number of subprocess to load data, default is 0.
'''
loader = self.init_evaluate(eval_dataset, batch_size, num_workers)
res = self.evaluate_process(loader)
return res
task_config:
run_mode: "ray"
workspace: "./work_dirs/pbt_hub_classifer/test_autoaug"
task_type: "classifier"
classifier:
model_name: "resnet50_vd_imagenet_ssld"
epochs: 100
input_size: 224
scale_size: 256
no_cache_img: false
use_class_map: false
data_config:
train_img_prefix: "./dataset/flower_photos"
train_ann_file: "./dataset/flower_photos/train_list.txt"
val_img_prefix: "./dataset/flower_photos"
val_ann_file: "./dataset/flower_photos/validate_list.txt"
label_list: "./dataset/flower_photos/label_list.txt"
delimiter: " "
resource_config:
gpu: 0.4
cpu: 1
algo_config:
algo_name: "PBA"
algo_param:
perturbation_interval: 3
num_samples: 8
search_space:
operator_space:
-
name: Sharpness
prob:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
magtitude:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
-
name: Rotate
prob:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
magtitude:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
-
name: Invert
prob:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
magtitude:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
-
name: Brightness
prob:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
magtitude:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
-
name: Cutout
prob:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
magtitude:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
-
name: Equalize
prob:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
magtitude:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
-
name: TranslateY
prob:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
magtitude:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
-
name: AutoContrast
prob:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
magtitude:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
-
name: Color
prob:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
magtitude:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
-
name: TranslateX
prob:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
magtitude:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
-
name: Solarize
prob:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
magtitude:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
-
name: ShearX
prob:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
magtitude:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
-
name: Contrast
prob:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
magtitude:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
-
name: Posterize
prob:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
magtitude:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
-
name: ShearY
prob:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
magtitude:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
-
name: FlipLR
prob:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
magtitude:
htype: choice
value: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
from auto_augment.autoaug.experiment.experiment import AutoAugExperiment
from auto_augment.autoaug.utils.yaml_config import get_config
from hub_fitter import HubFitterClassifer
import os
import argparse
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
parser = argparse.ArgumentParser()
parser.add_argument("--config",help="config file",)
parser.add_argument("--workspace",default=None, help="work_space",)
def main():
search_test()
def search_test():
args = parser.parse_args()
config = args.config
config = get_config(config, show=True)
task_config = config.task_config
data_config = config.data_config
resource_config = config.resource_config
algo_config = config.algo_config
search_space = config.get("search_space", None)
if args.workspace is not None:
task_config["workspace"] = args.workspace
workspace = task_config["workspace"]
    # load the algorithm, task, resource, data and (optional) search-space configs
exper = AutoAugExperiment.create(
algo_config=algo_config,
task_config=task_config,
resource_config=resource_config,
data_config=data_config,
search_space=search_space,
fitter=HubFitterClassifer
)
    result = exper.search()  # start the search job
    policy = result.get_best_policy()  # fetch the best policy; see the auto_augment docs for the policy format used when applying search results
print("policy is:{}".format(policy))
dump_path = os.path.join(workspace, "auto_aug_config.json")
result.dump_best_policy(
path=dump_path)
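    # train.sh later feeds this JSON file to train.py through its --policy argument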
if __name__ == "__main__":
main()
#!/usr/bin/env bash
export FLAGS_fast_eager_deletion_mode=1
export FLAGS_eager_delete_tensor_gb=0.0
config="./pba_classifier_example.yaml"
workspace="./work_dirs//autoaug_flower_mobilenetv2"
# the workspace directory needs to be initialized
rm -rf ${workspace}
mkdir -p ${workspace}
CUDA_VISIBLE_DEVICES=0,1 python -u search.py \
--config=${config} \
--workspace=${workspace} 2>&1 | tee -a ${workspace}/log.txt
# -*- coding: utf-8 -*-
#*******************************************************************************
#
# Copyright (c) 2020 Baidu.com, Inc. All Rights Reserved
#
#*******************************************************************************
"""
Authors: lvhaijun01@baidu.com
Date: 2020-11-26 20:57
"""
from auto_augment.autoaug.utils.yaml_config import get_config
from hub_fitter import HubFitterClassifer
import os
import argparse
import logging
import paddlehub as hub
import paddle
import paddlehub.vision.transforms as transforms
from paddlehub_utils.reader import _init_loader, PbaAugment
from paddlehub_utils.reader import _read_classes
from paddlehub_utils.trainer import CustomTrainer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
parser = argparse.ArgumentParser()
parser.add_argument("--config",help="config file",)
parser.add_argument("--workspace",default=None, help="work_space",)
parser.add_argument("--policy",default=None, help="data aug policy",)
if __name__ == '__main__':
args = parser.parse_args()
config = args.config
config = get_config(config, show=True)
task_config = config.task_config
data_config = config.data_config
resource_config = config.resource_config
algo_config = config.algo_config
input_size = task_config.classifier.input_size
scale_size = task_config.classifier.scale_size
normalize = {
'mean': [
0.485, 0.456, 0.406], 'std': [
0.229, 0.224, 0.225]}
epochs = task_config.classifier.epochs
policy = args.policy
if policy is None:
print("use normal train transform")
TrainTransform = transforms.Compose(
transforms=[
transforms.Resize(
(input_size,
input_size)),
transforms.Permute(),
transforms.Normalize(
**normalize, channel_first=True)],
channel_first = False)
else:
TrainTransform = PbaAugment(
input_size=input_size,
scale_size=scale_size,
normalize=normalize,
policy=policy,
hp_policy_epochs=epochs,
stage="train"
)
train_dataset, eval_dataset = _init_loader(config, TrainTransform=TrainTransform)
class_to_id_dict = _read_classes(config.data_config.label_list)
model = hub.Module(name=config.task_config.classifier.model_name, label_list=class_to_id_dict.keys(), load_checkpoint=None)
optimizer = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
trainer = CustomTrainer(
model=model,
optimizer=optimizer,
checkpoint_dir='img_classification_ckpt')
trainer.train(train_dataset, epochs=epochs, batch_size=32, eval_dataset=eval_dataset, save_interval=10)
#!/usr/bin/env bash
export FLAGS_fast_eager_deletion_mode=1
export FLAGS_eager_delete_tensor_gb=0.0
config="./pba_classifier_example.yaml"
workspace="./work_dirs//autoaug_flower_mobilenetv2"
# the workspace directory needs to be initialized
mkdir -p ${workspace}
policy=./work_dirs//autoaug_flower_mobilenetv2/auto_aug_config.json
CUDA_VISIBLE_DEVICES=0,1 python train.py \
--config=${config} \
--policy=${policy} \
--workspace=${workspace} 2>&1 | tee -a ${workspace}/log.txt
@@ -111,13 +111,18 @@ class Trainer(object):
logger.info('PaddleHub model checkpoint loaded. current_epoch={} [{}]'.format(
self.current_epoch, metric_msg))
model_path = os.path.join(self.checkpoint_dir, 'epoch_{}'.format(self.current_epoch))
self.load_model(model_path)
def load_model(self, load_dir: str):
"""load model"""
# load model checkpoint
model_params_path = os.path.join(self.checkpoint_dir, 'epoch_{}'.format(self.current_epoch), 'model.pdparams')
model_params_path = os.path.join(load_dir, 'model.pdparams')
state_dict = paddle.load(model_params_path)
self.model.set_state_dict(state_dict)
# load optimizer checkpoint
optim_params_path = os.path.join(self.checkpoint_dir, 'epoch_{}'.format(self.current_epoch), 'model.pdopt')
optim_params_path = os.path.join(load_dir, 'model.pdopt')
state_dict = paddle.load(optim_params_path)
self.optimizer.set_state_dict(state_dict)
......
@@ -29,8 +29,9 @@ class Compose:
Args:
transforms(callmethod) : The method of preprocess images.
to_rgb(bool): Whether to transform the input from BGR mode to RGB mode, default is False.
        channel_first(bool): Whether to permute the image from channel-last to channel-first, default is True.
"""
def __init__(self, transforms: Callable, to_rgb: bool = False):
def __init__(self, transforms: Callable, to_rgb: bool = False, channel_first: bool = True):
if not isinstance(transforms, list):
raise TypeError('The transforms must be a list!')
if len(transforms) < 1:
@@ -38,6 +39,7 @@ class Compose:
'must be equal or larger than 1!')
self.transforms = transforms
self.to_rgb = to_rgb
self.channel_first = channel_first
def __call__(self, im: Union[np.ndarray, str]):
if isinstance(im, str):
@@ -51,10 +53,19 @@ class Compose:
for op in self.transforms:
im = op(im)
im = F.permute(im)
if self.channel_first:
im = F.permute(im)
return im
class Permute:
"""
    Permute the input image from [H, W, C] to [C, H, W].
"""
def __init__(self):
pass
def __call__(self, im):
im = F.permute(im)
return im
class RandomHorizontalFlip:
"""
@@ -211,10 +222,12 @@ class Normalize:
Args:
mean(list): Mean value for normalization.
std(list): Standard deviation for normalization.
        channel_first(bool): Whether the input image is channel-first or channel-last, default is False.
"""
def __init__(self, mean: list = [0.5, 0.5, 0.5], std: list = [0.5, 0.5, 0.5]):
def __init__(self, mean: list = [0.5, 0.5, 0.5], std: list = [0.5, 0.5, 0.5], channel_first: bool = False):
self.mean = mean
self.std = std
self.channel_first = channel_first
if not (isinstance(self.mean, list) and isinstance(self.std, list)):
raise ValueError("{}: input type is invalid.".format(self))
from functools import reduce
@@ -222,8 +235,12 @@ class Normalize:
raise ValueError('{}: std is invalid!'.format(self))
def __call__(self, im):
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
if not self.channel_first:
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
else:
mean = np.array(self.mean)[:, np.newaxis, np.newaxis]
std = np.array(self.std)[:, np.newaxis, np.newaxis]
im = F.normalize(im, mean, std)
return im
......