Unverified commit 23104797 authored by cuicheng01 and committed by GitHub

Merge pull request #656 from flyseaworld/develop

add multilabel feature based on develop
......@@ -82,6 +82,7 @@ PaddleClas is a toolset for image classification tasks prepared for the industry
- Advanced tutorials
- [Knowledge distillation](./docs/en/advanced_tutorials/distillation/distillation_en.md)
- [Data augmentation](./docs/en/advanced_tutorials/image_augmentation/ImageAugment_en.md)
- [Multilabel classification](./docs/en/advanced_tutorials/multilabel/multilabel_en.md)
- Applications
- [Transfer learning](./docs/en/application/transfer_learning_en.md)
- [Pretrained model with 100,000 categories](./docs/en/application/transfer_learning_en.md)
......
......@@ -83,6 +83,7 @@
- Advanced tutorials
- [Knowledge distillation](./docs/zh_CN/advanced_tutorials/distillation/distillation.md)
- [Data augmentation](./docs/zh_CN/advanced_tutorials/image_augmentation/ImageAugment.md)
- [Multilabel classification](./docs/zh_CN/advanced_tutorials/multilabel/multilabel.md)
- Applications
- [Transfer learning](./docs/zh_CN/application/transfer_learning.md)
- [Pretrained model with 100,000 categories](./docs/zh_CN/application/transfer_learning.md)
......
mode: 'train'
ARCHITECTURE:
name: 'ResNet50_vd'
pretrained_model: "./pretrained/ResNet50_vd_pretrained"
model_save_dir: "./output/"
classes_num: 33
total_images: 17463
save_interval: 1
validate: True
valid_interval: 1
epochs: 10
topk: 1
image_shape: [3, 224, 224]
multilabel: True
use_mix: False
ls_epsilon: 0.1
LEARNING_RATE:
function: 'Cosine'
params:
lr: 0.07
OPTIMIZER:
function: 'Momentum'
params:
momentum: 0.9
regularizer:
function: 'L2'
factor: 0.000070
TRAIN:
batch_size: 256
num_workers: 4
file_list: "./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/multilabel_train_list.txt"
data_dir: "./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
to_np: False
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
mix:
- MixupOperator:
alpha: 0.2
VALID:
batch_size: 64
num_workers: 4
file_list: "./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/multilabel_test_list.txt"
data_dir: "./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
to_np: False
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
\ No newline at end of file
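For reference, the `MultiLabelDataset` added later in this PR parses each line of `file_list` as an image path and a comma-separated multi-hot label, joined by a tab. A hypothetical line (shortened to 6 of the 33 label positions for readability) looks like:

```
0199_434752251.jpg	0,1,0,0,1,0
```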
......@@ -6,4 +6,4 @@ advanced_tutorials
image_augmentation/index
distillation/index
multilabel/index
Multilabel Classification
================================

.. toctree::
   :maxdepth: 3

   multilabel.md
\ No newline at end of file
# Multilabel classification quick start
Based on the [NUS-WIDE-SCENE](https://lms.comp.nus.edu.sg/wp-content/uploads/2019/research/nuswide/NUS-WIDE.html) dataset, a subset of the NUS-WIDE dataset, you can experience multilabel classification in PaddleClas, including training, evaluation, and prediction. Please refer to [Installation](install.md) to install PaddleClas first.
## Preparation
* Enter the PaddleClas directory
```
cd path_to_PaddleClas
```
* Create and enter the `dataset/NUS-WIDE-SCENE` directory, then download and decompress the NUS-WIDE-SCENE dataset
```shell
mkdir dataset/NUS-WIDE-SCENE
cd dataset/NUS-WIDE-SCENE
wget https://paddle-imagenet-models-name.bj.bcebos.com/data/NUS-SCENE-dataset.tar
tar -xf NUS-SCENE-dataset.tar
```
* Return to the `PaddleClas` root directory
```
cd ../../
```
## Environment
### Download pretrained model
You can use the following commands to download the ResNet50_vd pretrained model.
```bash
mkdir pretrained
cd pretrained
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_pretrained.pdparams
cd ../
```
## Training
```shell
export CUDA_VISIBLE_DEVICES=0
python -m paddle.distributed.launch \
--gpus="0" \
tools/train.py \
-c ./configs/quick_start/ResNet50_vd_multilabel.yaml
```
After training for 10 epochs, the best accuracy on the validation set should be around 0.72.
## Evaluation
```bash
python tools/eval.py \
-c ./configs/quick_start/ResNet50_vd_multilabel.yaml \
-o pretrained_model="./output/ResNet50_vd/best_model/ppcls" \
-o load_static_weights=False
```
The evaluation metric is mAP (mean average precision), which is commonly used in multilabel tasks to assess model performance. The mAP on the validation set should be around 0.57.
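Concretely, mAP here is the unweighted mean over the 33 classes of each class's average precision. A minimal sketch of the computation, mirroring the `mean_average_precision` helper added in this PR:

```python
import numpy as np
from sklearn.metrics import average_precision_score

def mean_average_precision(scores, target):
    # average precision per class (one score column per class), then the mean
    aps = [average_precision_score(target[:, c], scores[:, c])
           for c in range(target.shape[1])]
    return np.mean(aps)
```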
## Prediction
```bash
python tools/infer/infer.py \
-i "./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images/0199_434752251.jpg" \
--model ResNet50_vd \
--pretrained_model "./output/ResNet50_vd/best_model/ppcls" \
--use_gpu True \
--load_static_weights False \
--multilabel True \
--class_num 33
```
You will get multiple outputs, such as the following:
```
class id: 3, probability: 0.6025
class id: 23, probability: 0.5491
class id: 32, probability: 0.7006
```
\ No newline at end of file
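The printed class ids are those whose sigmoid probability reaches 0.5, per the `postprocess` change in this PR. A minimal sketch of the selection (probability values are hypothetical):

```python
import numpy as np

probs = np.array([0.1, 0.2, 0.3, 0.6025, 0.05, 0.5491])  # hypothetical sigmoid outputs
index = np.where(probs >= 0.5)[0]  # keep every class scoring at least 0.5
print(index)  # -> [3 5]
```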
......@@ -6,4 +6,4 @@
image_augmentation/index
distillation/index
multilabel/index
Multilabel Classification
================================

.. toctree::
   :maxdepth: 3

   multilabel.md
\ No newline at end of file
# Multilabel classification quick start
Based on the [NUS-WIDE-SCENE](https://lms.comp.nus.edu.sg/wp-content/uploads/2019/research/nuswide/NUS-WIDE.html) dataset, a subset of the NUS-WIDE dataset, this guide walks through the training, evaluation, and prediction workflow for multilabel classification. Please refer to the [Installation guide](install.md) beforehand to configure the runtime environment and clone the PaddleClas code.
## 1. Data and model preparation
* Enter the PaddleClas directory.
```
cd path_to_PaddleClas
```
* Create and enter the `dataset/NUS-WIDE-SCENE` directory, then download and decompress the NUS-WIDE-SCENE dataset.
```shell
mkdir dataset/NUS-WIDE-SCENE
cd dataset/NUS-WIDE-SCENE
wget https://paddle-imagenet-models-name.bj.bcebos.com/data/NUS-SCENE-dataset.tar
tar -xf NUS-SCENE-dataset.tar
```
* Return to the `PaddleClas` root directory
```
cd ../../
```
## 2. Environment preparation
### 2.1 Download the pretrained model
This example demonstrates the multilabel classification workflow based on the ResNet50_vd model, so first download the ResNet50_vd pretrained model.
```bash
mkdir pretrained
cd pretrained
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_pretrained.pdparams
cd ../
```
## 3. Model training
```shell
export CUDA_VISIBLE_DEVICES=0
python -m paddle.distributed.launch \
--gpus="0" \
tools/train.py \
-c ./configs/quick_start/ResNet50_vd_multilabel.yaml
```
After training for 10 epochs, the best accuracy on the validation set should be around 0.72.
## 4. Model evaluation
```bash
python tools/eval.py \
-c ./configs/quick_start/ResNet50_vd_multilabel.yaml \
-o pretrained_model="./output/ResNet50_vd/best_model/ppcls" \
-o load_static_weights=False
```
Evaluation uses the mAP metric; the mAP on the validation set should be around 0.57.
## 5. Model prediction
```bash
python tools/infer/infer.py \
-i "./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images/0199_434752251.jpg" \
--model ResNet50_vd \
--pretrained_model "./output/ResNet50_vd/best_model/ppcls" \
--use_gpu True \
--load_static_weights False \
--multilabel True \
--class_num 33
```
You will get output similar to the following:
```
class id: 3, probability: 0.6025
class id: 23, probability: 0.5491
class id: 32, probability: 0.7006
```
\ No newline at end of file
......@@ -197,6 +197,40 @@ class CommonDataset(Dataset):
def __len__(self):
return self.num_samples
class MultiLabelDataset(Dataset):
"""
Define dataset class for multilabel image classification
"""
def __init__(self, params):
self.params = params
self.mode = params.get("mode", "train")
self.full_lines = get_file_list(params)
self.delimiter = params.get("delimiter", "\t")
self.ops = create_operators(params["transforms"])
self.num_samples = len(self.full_lines)
return
def __getitem__(self, idx):
try:
line = self.full_lines[idx]
img_path, label_str = line.split(self.delimiter)
img_path = os.path.join(self.params["data_dir"], img_path)
with open(img_path, "rb") as f:
img = f.read()
labels = label_str.split(',')
labels = [int(i) for i in labels]
return (transform(img, self.ops), np.array(labels).astype("float32"))
except Exception as e:
logger.error("data read failed: {}, exception info: {}".format(line, e))
            return self.__getitem__(random.randint(0, len(self) - 1))
def __len__(self):
return self.num_samples
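A minimal usage sketch of the new dataset class, assuming it lives next to `CommonDataset` in `ppcls/data/reader.py` and that `params` mirrors the VALID section of the YAML config above (the transform list is abbreviated):

```python
from ppcls.data.reader import MultiLabelDataset

params = {
    "mode": "valid",
    "file_list": "./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/multilabel_test_list.txt",
    "data_dir": "./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images",
    "shuffle_seed": 0,
    "transforms": [
        {"DecodeImage": {"to_rgb": True, "to_np": False, "channel_first": False}},
        {"ResizeImage": {"resize_short": 256}},
        {"CropImage": {"size": 224}},
        {"ToCHWImage": {}},
    ],
}
dataset = MultiLabelDataset(params)
img, label = dataset[0]  # img: transformed CHW array; label: float32 multi-hot vector
```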
class Reader:
......@@ -229,6 +263,7 @@ class Reader:
self.collate_fn = self.mix_collate_fn
self.places = places
self.multilabel = config.get("multilabel", False)
def mix_collate_fn(self, batch):
batch = transform(batch, self.batch_ops)
......@@ -246,7 +281,10 @@ class Reader:
def __call__(self):
batch_size = int(self.params['batch_size']) // trainers_num
dataset = CommonDataset(self.params)
if self.multilabel:
dataset = MultiLabelDataset(self.params)
else:
dataset = CommonDataset(self.params)
is_train = self.params['mode'] == "train"
batch_sampler = DistributedBatchSampler(
......
......@@ -15,7 +15,7 @@
import paddle
import paddle.nn.functional as F
__all__ = ['CELoss', 'MixCELoss', 'GoogLeNetLoss', 'JSDivLoss']
__all__ = ['CELoss', 'MixCELoss', 'GoogLeNetLoss', 'JSDivLoss', 'MultiLabelLoss']
class Loss(object):
......@@ -41,6 +41,17 @@ class Loss(object):
soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon)
soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
return soft_target
    def _binary_crossentropy(self, input, target):
        if self._label_smoothing:
            target = self._labelsmoothing(target)
        cost = F.binary_cross_entropy_with_logits(logit=input, label=target)
        avg_cost = paddle.mean(cost)
        return avg_cost
def _crossentropy(self, input, target):
if self._label_smoothing:
......@@ -68,6 +79,20 @@ class Loss(object):
def __call__(self, input, target):
pass
class MultiLabelLoss(Loss):
"""
    Multilabel loss based on binary cross entropy
"""
def __init__(self, class_dim=1000, epsilon=None):
super(MultiLabelLoss, self).__init__(class_dim, epsilon)
def __call__(self, input, target):
cost = self._binary_crossentropy(input, target)
return cost
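A minimal usage sketch (shapes only; `class_dim=33` matches the quick-start config, and `epsilon` is left unset so the base class's label smoothing is skipped):

```python
import paddle
from ppcls.modeling.loss import MultiLabelLoss

logits = paddle.randn([4, 33])                              # raw network outputs
targets = paddle.randint(0, 2, [4, 33]).astype("float32")   # multi-hot labels

loss_fn = MultiLabelLoss(class_dim=33)
cost = loss_fn(logits, targets)  # scalar: mean binary cross entropy with logits
```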
class CELoss(Loss):
......
......@@ -15,7 +15,13 @@
from . import logger
from . import misc
from . import model_zoo
from . import metrics
from .save_load import init_model, save_model
from .config import get_config
from .misc import AverageMeter
from .metrics import multi_hot_encode
from .metrics import hamming_distance
from .metrics import accuracy_score
from .metrics import precision_recall_fscore
from .metrics import mean_average_precision
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from sklearn.metrics import hamming_loss
from sklearn.metrics import accuracy_score as accuracy_metric
from sklearn.metrics import multilabel_confusion_matrix
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import average_precision_score
from sklearn.preprocessing import binarize
import numpy as np
__all__ = ["multi_hot_encode", "hamming_distance", "accuracy_score", "precision_recall_fscore", "mean_average_precision"]
def multi_hot_encode(logits, threshold=0.5):
    """
    Encode logits into multi-hot vectors element-wise for multilabel tasks
    """
    return binarize(logits, threshold=threshold)
def hamming_distance(output, target):
"""
Soft metric based label for multilabel classification
Returns:
The smaller the return value is, the better model is.
"""
return hamming_loss(target, output)
def accuracy_score(output, target, base="sample"):
"""
Hard metric for multilabel classification
Args:
output:
target:
base: ["sample", "label"], default="sample"
if "sample", return metric score based sample,
if "label", return metric score based label.
Returns:
accuracy:
"""
assert base in ["sample", "label"], 'must be one of ["sample", "label"]'
if base == "sample":
accuracy = accuracy_metric(target, output)
elif base == "label":
mcm = multilabel_confusion_matrix(target, output)
tns = mcm[:, 0, 0]
fns = mcm[:, 1, 0]
tps = mcm[:, 1, 1]
fps = mcm[:, 0, 1]
accuracy = (sum(tps) + sum(tns)) / (sum(tps) + sum(tns) + sum(fns) + sum(fps))
return accuracy
def precision_recall_fscore(output, target):
"""
Metric based label for multilabel classification
Returns:
precisions:
recalls:
fscores:
"""
precisions, recalls, fscores, _ = precision_recall_fscore_support(target, output)
return precisions, recalls, fscores
def mean_average_precision(logits, target):
"""
Calculate average precision
Args:
logits: probability from network before sigmoid or softmax
target: ground truth, 0 or 1
"""
if not (isinstance(logits, np.ndarray) and isinstance(target, np.ndarray)):
raise TypeError("logits and target should be np.ndarray.")
aps = []
for i in range(target.shape[1]):
ap = average_precision_score(target[:, i], logits[:, i])
aps.append(ap)
return np.mean(aps)
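A minimal sketch of how these helpers compose on a toy batch (the values are illustrative; imports follow the `ppcls.utils` re-exports added above):

```python
import numpy as np
from ppcls.utils import (multi_hot_encode, hamming_distance,
                         accuracy_score, mean_average_precision)

probs = np.array([[0.9, 0.2, 0.7],
                  [0.1, 0.8, 0.4]])   # sigmoid outputs: 2 samples, 3 classes
target = np.array([[1, 0, 1],
                   [0, 1, 1]])

preds = multi_hot_encode(probs)               # threshold at 0.5
print(hamming_distance(preds, target))        # 1 wrong label out of 6 -> ~0.167
print(accuracy_score(preds, target))          # exact-match ratio -> 0.5
print(mean_average_precision(probs, target))  # mean per-class average precision
```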
......@@ -5,3 +5,4 @@ tqdm
PyYAML
visualdl >= 2.0.0b
scipy
scikit-learn==0.23.2
......@@ -13,6 +13,7 @@
# limitations under the License.
import paddle
import paddle.nn.functional as F
import argparse
import os
......@@ -24,9 +25,15 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
from ppcls.utils import logger
from ppcls.utils.save_load import init_model
from ppcls.utils.config import get_config
from ppcls.utils import multi_hot_encode
from ppcls.utils import accuracy_score
from ppcls.utils import mean_average_precision
from ppcls.utils import precision_recall_fscore
from ppcls.data import Reader
import program
import numpy as np
def parse_args():
parser = argparse.ArgumentParser("PaddleClas eval script")
......@@ -52,6 +59,7 @@ def main(args, return_dict={}):
# assign place
use_gpu = config.get("use_gpu", True)
place = paddle.set_device('gpu' if use_gpu else 'cpu')
multilabel = config.get("multilabel", False)
trainer_num = paddle.distributed.get_world_size()
use_data_parallel = trainer_num != 1
......@@ -68,12 +76,38 @@ def main(args, return_dict={}):
valid_dataloader = Reader(config, 'valid', places=place)()
net.eval()
with paddle.no_grad():
top1_acc = program.run(valid_dataloader, config, net, None, None, 0,
'valid')
return_dict["top1_acc"] = top1_acc
return top1_acc
if not multilabel:
top1_acc = program.run(valid_dataloader, config, net, None, None, 0,
'valid')
return_dict["top1_acc"] = top1_acc
return top1_acc
else:
all_outs = []
targets = []
for idx, batch in enumerate(valid_dataloader()):
feeds = program.create_feeds(batch, False, config.classes_num, multilabel)
out = net(feeds["image"])
out = F.sigmoid(out)
use_distillation = config.get("use_distillation", False)
if use_distillation:
out = out[1]
all_outs.extend(list(out.numpy()))
targets.extend(list(feeds["label"].numpy()))
all_outs = np.array(all_outs)
targets = np.array(targets)
mAP = mean_average_precision(all_outs, targets)
return_dict["mean average precision"] = mAP
return mAP
if __name__ == '__main__':
args = parse_args()
main(args)
return_dict = {}
main(args, return_dict)
print(return_dict)
......@@ -34,6 +34,7 @@ def main():
args = parse_args()
# assign the place
place = paddle.set_device('gpu' if args.use_gpu else 'cpu')
    multilabel = args.multilabel
net = architectures.__dict__[args.model](class_dim=args.class_num)
load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights)
......@@ -61,17 +62,25 @@ def main():
batch_outputs = net(batch_tensor)
if args.model == "GoogLeNet":
batch_outputs = batch_outputs[0]
batch_outputs = F.softmax(batch_outputs)
if multilabel:
batch_outputs = F.sigmoid(batch_outputs)
else:
batch_outputs = F.softmax(batch_outputs)
batch_outputs = batch_outputs.numpy()
batch_result_list = postprocess(batch_outputs, args.top_k)
batch_result_list = postprocess(batch_outputs, args.top_k, multilabel=multilabel)
for number, result_dict in enumerate(batch_result_list):
filename = img_path_list[number].split("/")[-1]
clas_ids = result_dict["clas_ids"]
scores_str = "[{}]".format(", ".join("{:.2f}".format(
r) for r in result_dict["scores"]))
print("File:{}, Top-{} result: class id(s): {}, score(s): {}".
format(filename, args.top_k, clas_ids, scores_str))
if multilabel:
print("File:{}, multilabel result: ".format(filename))
            for class_id, score in zip(clas_ids, result_dict["scores"]):
                print("\tclass id: {}, probability: {:.2f}".format(class_id, score))
else:
scores_str = "[{}]".format(", ".join("{:.2f}".format(
r) for r in result_dict["scores"]))
print("File:{}, Top-{} result: class id(s): {}, score(s): {}".
format(filename, args.top_k, clas_ids, scores_str))
if args.pre_label_image:
save_prelabel_results(clas_ids[0], img_path_list[number],
......
......@@ -31,6 +31,7 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--image_file", type=str)
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--multilabel", type=str2bool, default=False)
# params for preprocess
parser.add_argument("--resize_short", type=int, default=256)
......@@ -124,11 +125,14 @@ def preprocess(img, args):
return img
def postprocess(batch_outputs, topk=5):
def postprocess(batch_outputs, topk=5, multilabel=False):
batch_results = []
for probs in batch_outputs:
results = []
index = probs.argsort(axis=0)[-topk:][::-1].astype("int32")
if multilabel:
index = np.where(probs >= 0.5)[0].astype('int32')
else:
index = probs.argsort(axis=0)[-topk:][::-1].astype("int32")
clas_id_list = []
score_list = []
for i in index:
......
......@@ -29,12 +29,16 @@ import paddle.nn.functional as F
from ppcls.optimizer import LearningRateBuilder
from ppcls.optimizer import OptimizerBuilder
from ppcls.modeling import architectures
from ppcls.modeling.loss import MultiLabelLoss
from ppcls.modeling.loss import CELoss
from ppcls.modeling.loss import MixCELoss
from ppcls.modeling.loss import JSDivLoss
from ppcls.modeling.loss import GoogLeNetLoss
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils import multi_hot_encode
from ppcls.utils import hamming_distance
from ppcls.utils import accuracy_score
def create_model(architecture, classes_num):
......@@ -61,7 +65,8 @@ def create_loss(feeds,
classes_num=1000,
epsilon=None,
use_mix=False,
use_distillation=False):
use_distillation=False,
multilabel=False):
"""
Create a loss for optimization, such as:
1. CrossEnotry loss
......@@ -100,7 +105,10 @@ def create_loss(feeds,
feed_lam = feeds['lam']
return loss(out, feed_y_a, feed_y_b, feed_lam)
else:
loss = CELoss(class_dim=classes_num, epsilon=epsilon)
if not multilabel:
loss = CELoss(class_dim=classes_num, epsilon=epsilon)
else:
loss = MultiLabelLoss(class_dim=classes_num, epsilon=epsilon)
return loss(out, feeds["label"])
......@@ -110,6 +118,7 @@ def create_metric(out,
topk=5,
classes_num=1000,
use_distillation=False,
multilabel=False,
mode="train"):
"""
Create measures of model accuracy, such as top1 and top5
......@@ -135,24 +144,43 @@ def create_metric(out,
softmax_out = F.softmax(out)
fetchs = OrderedDict()
# set top1 to fetchs
top1 = paddle.metric.accuracy(softmax_out, label=label, k=1)
# set topk to fetchs
k = min(topk, classes_num)
topk = paddle.metric.accuracy(softmax_out, label=label, k=k)
metric_names = set()
if not multilabel:
softmax_out = F.softmax(out)
# set top1 to fetchs
top1 = paddle.metric.accuracy(softmax_out, label=label, k=1)
# set topk to fetchs
k = min(topk, classes_num)
topk = paddle.metric.accuracy(softmax_out, label=label, k=k)
metric_names.add("top1")
metric_names.add("top{}".format(k))
fetchs['top1'] = top1
topk_name = "top{}".format(k)
fetchs[topk_name] = topk
else:
out = F.sigmoid(out)
preds = multi_hot_encode(out.numpy())
targets = label.numpy()
ham_dist = to_tensor(hamming_distance(preds, targets))
accuracy = to_tensor(accuracy_score(preds, targets, base="label"))
ham_dist_name = "hamming_distance"
accuracy_name = "multilabel_accuracy"
metric_names.add(ham_dist_name)
metric_names.add(accuracy_name)
fetchs[accuracy_name] = accuracy
fetchs[ham_dist_name] = ham_dist
# multi cards' eval
if mode != "train" and paddle.distributed.get_world_size() > 1:
top1 = paddle.distributed.all_reduce(
top1, op=paddle.distributed.ReduceOp.
SUM) / paddle.distributed.get_world_size()
topk = paddle.distributed.all_reduce(
topk, op=paddle.distributed.ReduceOp.
SUM) / paddle.distributed.get_world_size()
fetchs['top1'] = top1
topk_name = 'top{}'.format(k)
fetchs[topk_name] = topk
for metric_name in metric_names:
fetchs[metric_name] = paddle.distributed.all_reduce(
fetchs[metric_name], op=paddle.distributed.ReduceOp.
SUM) / paddle.distributed.get_world_size()
return fetchs
......@@ -182,12 +210,14 @@ def create_fetchs(feeds, net, config, mode="train"):
epsilon = config.get('ls_epsilon')
use_mix = config.get('use_mix') and mode == 'train'
use_distillation = config.get('use_distillation')
multilabel = config.get('multilabel', False)
out = net(feeds["image"])
fetchs = OrderedDict()
fetchs['loss'] = create_loss(feeds, out, architecture, classes_num,
epsilon, use_mix, use_distillation)
epsilon, use_mix, use_distillation,
multilabel)
if not use_mix:
metric = create_metric(
out,
......@@ -196,6 +226,7 @@ def create_fetchs(feeds, net, config, mode="train"):
topk,
classes_num,
use_distillation,
multilabel=multilabel,
mode=mode)
fetchs.update(metric)
......@@ -240,7 +271,7 @@ def create_optimizer(config, parameter_list=None):
return opt(lr, parameter_list), lr
def create_feeds(batch, use_mix):
def create_feeds(batch, use_mix, num_classes, multilabel=False):
image = batch[0]
if use_mix:
y_a = to_tensor(batch[1].numpy().astype("int64").reshape(-1, 1))
......@@ -248,7 +279,10 @@ def create_feeds(batch, use_mix):
lam = to_tensor(batch[3].numpy().astype("float32").reshape(-1, 1))
feeds = {"image": image, "y_a": y_a, "y_b": y_b, "lam": lam}
else:
label = to_tensor(batch[1].numpy().astype('int64').reshape(-1, 1))
if not multilabel:
label = to_tensor(batch[1].numpy().astype("int64").reshape(-1, 1))
else:
label = to_tensor(batch[1].numpy().astype('float32').reshape(-1, num_classes))
feeds = {"image": image, "label": label}
return feeds
......@@ -279,6 +313,8 @@ def run(dataloader,
"""
print_interval = config.get("print_interval", 10)
use_mix = config.get("use_mix", False) and mode == "train"
multilabel = config.get("multilabel", False)
classes_num = config.get("classes_num")
metric_list = [
("loss", AverageMeter(
......@@ -291,13 +327,19 @@ def run(dataloader,
'reader_cost', '.5f', postfix=" s,")),
]
if not use_mix:
topk_name = 'top{}'.format(config.topk)
metric_list.insert(
0, (topk_name, AverageMeter(
topk_name, '.5f', postfix=",")))
metric_list.insert(
0, ("top1", AverageMeter(
"top1", '.5f', postfix=",")))
if not multilabel:
topk_name = 'top{}'.format(config.topk)
metric_list.insert(
0, (topk_name, AverageMeter(
topk_name, '.5f', postfix=",")))
metric_list.insert(
0, ("top1", AverageMeter(
"top1", '.5f', postfix=",")))
else:
metric_list.insert(0, ("multilabel_accuracy", AverageMeter(
"multilabel_accuracy", '.5f', postfix=",")))
metric_list.insert(0, ("hamming_distance", AverageMeter(
"hamming_distance", '.5f', postfix=",")))
metric_list = OrderedDict(metric_list)
......@@ -310,7 +352,7 @@ def run(dataloader,
metric_list['reader_time'].update(time.time() - tic)
batch_size = len(batch[0])
feeds = create_feeds(batch, use_mix)
feeds = create_feeds(batch, use_mix, classes_num, multilabel)
fetchs = create_fetchs(feeds, net, config, mode)
if mode == 'train':
avg_loss = fetchs['loss']
......@@ -387,4 +429,7 @@ def run(dataloader,
# return top1_acc in order to save the best model
if mode == 'valid':
return metric_list['top1'].avg
if multilabel:
return metric_list['multilabel_accuracy'].avg
else:
return metric_list['top1'].avg