提交 7f55c6e8 编写于 作者: G gaotingquan 提交者: Tingquan Gao

refactor: deprecate MixCELoss

上级 c30b72c8
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
Eval:
- CELoss:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
Eval:
- CELoss:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -24,7 +24,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -24,7 +24,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -24,7 +24,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -24,7 +24,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -24,7 +24,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -24,7 +24,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -26,7 +26,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -26,7 +26,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -26,7 +26,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -26,7 +26,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -26,7 +26,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -26,7 +26,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -30,7 +30,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -29,7 +29,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- MixCELoss:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import copy
import paddle
import numpy as np
......@@ -36,7 +37,7 @@ from ppcls.data import preprocess
from ppcls.data.preprocess import transform
def create_operators(params):
def create_operators(params, class_num=None):
"""
create operators based on the config
......@@ -50,7 +51,10 @@ def create_operators(params):
dict) and len(operator) == 1, "yaml format error"
op_name = list(operator)[0]
param = {} if operator[op_name] is None else operator[op_name]
op = getattr(preprocess, op_name)(**param)
op_func = getattr(preprocess, op_name)
if "class_num" in inspect.getfullargspec(op_func).args:
param.update({"class_num": class_num})
op = op_func(**param)
ops.append(op)
return ops
......@@ -65,6 +69,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
from ppcls.data.dataloader.dali import dali_dataloader
return dali_dataloader(config, mode, paddle.device.get_device(), seed)
class_num = config.get("class_num", None)
config_dataset = config[mode]['dataset']
config_dataset = copy.deepcopy(config_dataset)
dataset_name = config_dataset.pop('name')
......@@ -104,7 +109,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
return [np.stack(slot, axis=0) for slot in slots]
if isinstance(batch_transform, list):
batch_ops = create_operators(batch_transform)
batch_ops = create_operators(batch_transform, class_num)
batch_collate_fn = mix_collate_fn
else:
batch_collate_fn = None
......
......@@ -44,6 +44,14 @@ class BatchOperator(object):
labels.append(item[1])
return np.array(imgs), np.array(labels), bs
def _one_hot(self, targets):
return np.eye(self.class_num, dtype="float32")[targets]
def _mix_target(self, targets0, targets1, lam):
one_hots0 = self._one_hot(targets0)
one_hots1 = self._one_hot(targets1)
return one_hots0 * lam + one_hots1 * (1 - lam)
def __call__(self, batch):
return batch
......@@ -51,7 +59,7 @@ class BatchOperator(object):
class MixupOperator(BatchOperator):
""" Mixup operator """
def __init__(self, alpha: float=1.):
def __init__(self, class_num, alpha: float=1.):
"""Build Mixup operator
Args:
......@@ -64,21 +72,27 @@ class MixupOperator(BatchOperator):
raise Exception(
f"Parameter \"alpha\" of Mixup should be greater than 0. \"alpha\": {alpha}."
)
if not class_num:
msg = "Please set \"Arch.class_num\" in config if use \"MixupOperator\"."
logger.error(Exception(msg))
raise Exception(msg)
self._alpha = alpha
self.class_num = class_num
def __call__(self, batch):
imgs, labels, bs = self._unpack(batch)
idx = np.random.permutation(bs)
lam = np.random.beta(self._alpha, self._alpha)
lams = np.array([lam] * bs, dtype=np.float32)
imgs = lam * imgs + (1 - lam) * imgs[idx]
return list(zip(imgs, labels, labels[idx], lams))
targets = self._mix_target(labels, labels[idx], lam)
return list(zip(imgs, targets))
class CutmixOperator(BatchOperator):
""" Cutmix operator """
def __init__(self, alpha=0.2):
def __init__(self, class_num, alpha=0.2):
"""Build Cutmix operator
Args:
......@@ -91,7 +105,13 @@ class CutmixOperator(BatchOperator):
raise Exception(
f"Parameter \"alpha\" of Cutmix should be greater than 0. \"alpha\": {alpha}."
)
if not class_num:
msg = "Please set \"Arch.class_num\" in config if use \"CutmixOperator\"."
logger.error(Exception(msg))
raise Exception(msg)
self._alpha = alpha
self.class_num = class_num
def _rand_bbox(self, size, lam):
""" _rand_bbox """
......@@ -121,18 +141,29 @@ class CutmixOperator(BatchOperator):
imgs[:, :, bbx1:bbx2, bby1:bby2] = imgs[idx, :, bbx1:bbx2, bby1:bby2]
lam = 1 - (float(bbx2 - bbx1) * (bby2 - bby1) /
(imgs.shape[-2] * imgs.shape[-1]))
lams = np.array([lam] * bs, dtype=np.float32)
return list(zip(imgs, labels, labels[idx], lams))
targets = self._mix_target(labels, labels[idx], lam)
return list(zip(imgs, targets))
class FmixOperator(BatchOperator):
""" Fmix operator """
def __init__(self, alpha=1, decay_power=3, max_soft=0., reformulate=False):
def __init__(self,
class_num,
alpha=1,
decay_power=3,
max_soft=0.,
reformulate=False):
if not class_num:
msg = "Please set \"Arch.class_num\" in config if use \"FmixOperator\"."
logger.error(Exception(msg))
raise Exception(msg)
self._alpha = alpha
self._decay_power = decay_power
self._max_soft = max_soft
self._reformulate = reformulate
self.class_num = class_num
def __call__(self, batch):
imgs, labels, bs = self._unpack(batch)
......@@ -141,20 +172,27 @@ class FmixOperator(BatchOperator):
lam, mask = sample_mask(self._alpha, self._decay_power, \
size, self._max_soft, self._reformulate)
imgs = mask * imgs + (1 - mask) * imgs[idx]
return list(zip(imgs, labels, labels[idx], [lam] * bs))
targets = self._mix_target(labels, labels[idx], lam)
return list(zip(imgs, targets))
class OpSampler(object):
""" Sample a operator from """
def __init__(self, **op_dict):
def __init__(self, class_num, **op_dict):
"""Build OpSampler
Raises:
Exception: The parameter \"prob\" of operator(s) are be set error.
"""
if not class_num:
msg = "Please set \"Arch.class_num\" in config if use \"OpSampler\"."
logger.error(Exception(msg))
raise Exception(msg)
if len(op_dict) < 1:
msg = f"ConfigWarning: No operator in \"OpSampler\". \"OpSampler\" has been skipped."
logger.warning(msg)
self.ops = {}
total_prob = 0
......@@ -165,12 +203,13 @@ class OpSampler(object):
logger.warning(msg)
prob = param.pop("prob", 0)
total_prob += prob
param.update({"class_num": class_num})
op = eval(op_name)(**param)
self.ops.update({op: prob})
if total_prob > 1:
msg = f"ConfigError: The total prob of operators in \"OpSampler\" should be less 1."
logger.error(msg)
logger.error(Exception(msg))
raise Exception(msg)
# add "None Op" when total_prob < 1, "None Op" do nothing
......
......@@ -112,6 +112,8 @@ class Engine(object):
}
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
class_num = config["Arch"].get("class_num", None)
self.config["DataLoader"].update({"class_num": class_num})
# build dataloader
if self.mode == 'train':
self.train_dataloader = build_dataloader(
......
......@@ -34,7 +34,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
]
batch_size = batch[0].shape[0]
if not engine.config["Global"].get("use_multilabel", False):
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
batch[1] = batch[1].reshape([batch_size, -1])
engine.global_step += 1
# image input
......@@ -47,12 +47,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
else:
out = forward(engine, batch)
# calc loss
if engine.config["DataLoader"]["Train"]["dataset"].get(
"batch_transform_ops", None):
loss_dict = engine.train_loss_func(out, batch[1:])
else:
loss_dict = engine.train_loss_func(out, batch[1])
loss_dict = engine.train_loss_func(out, batch[1])
# step opt and lr
if engine.amp:
......
......@@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppcls.utils import logger
class CELoss(nn.Layer):
"""
......@@ -56,19 +60,8 @@ class CELoss(nn.Layer):
return {"CELoss": loss}
class MixCELoss(CELoss):
"""
Cross entropy loss with mix(mixup, cutmix, fixmix)
"""
def __init__(self, epsilon=None):
super().__init__()
self.epsilon = epsilon
def __call__(self, input, batch):
target0, target1, lam = batch
loss0 = super().forward(input, target0)["CELoss"]
loss1 = super().forward(input, target1)["CELoss"]
loss = lam * loss0 + (1.0 - lam) * loss1
loss = paddle.mean(loss)
return {"MixCELoss": loss}
class MixCELoss(object):
def __init__(self, *args, **kwargs):
msg = "\"MixCELos\" is deprecated, please use \"CELoss\" instead."
logger.error(DeprecationWarning(msg))
raise DeprecationWarning(msg)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册