@@ -107,20 +124,6 @@ class MixupOperator(BatchOperator):
lams=np.array([lam]*bs,dtype=np.float32)
returnlist(zip(imgs,labels,labels[idx],lams))
def __call__(self, batch):
    """Randomly apply either CutMix or MixUp to one batch.

    A uniform draw against ``self._switch_prob`` selects the CutMix
    branch; otherwise MixUp is used. Both branches return the same
    ``(image, label_a, label_b, lam)`` tuple list produced by the
    respective helper.
    """
    imgs, labels, bs = self._unpack(batch)
    use_cutmix = np.random.rand() < self._switch_prob
    if use_cutmix:
        return self._cutmix(imgs, labels, bs)
    return self._mixup(imgs, labels, bs)
class CutmixOperator(BatchOperator):
    """Deprecated placeholder kept so stale configs fail loudly.

    Cutmix is now provided by ``MixupOperator`` through its
    ``cutmix_alpha`` and ``switch_prob`` arguments; constructing this
    class always raises.
    """

    def __init__(self, **kwargs):
        # Fail immediately at construction time with a pointer to the
        # replacement API. Plain string (the original used an f-string
        # with no placeholders) and fixed the "Refor" -> "Refer" typo.
        raise Exception(
            "\"CutmixOperator\" has been deprecated. Please use MixupOperator "
            "with \"cutmix_alpha\" and \"switch_prob\" to enable Cutmix. "
            "Refer to doc for details."
        )
classFmixOperator(BatchOperator):
""" Fmix operator """
...
...
@@ -139,3 +142,42 @@ class FmixOperator(BatchOperator):
size,self._max_soft,self._reformulate)
imgs=mask*imgs+(1-mask)*imgs[idx]
returnlist(zip(imgs,labels,labels[idx],[lam]*bs))
classOpSampler(object):
""" Sample a operator from """
def__init__(self,**op_dict):
"""Build OpSampler
Raises:
Exception: The parameter \"prob\" of operator(s) are be set error.
"""
iflen(op_dict)<1:
msg=f"ConfigWarning: No operator in \"OpSampler\". \"OpSampler\" has been skipped."
self.ops={}
total_prob=0
forop_nameinop_dict:
param=op_dict[op_name]
if"prob"notinparam:
msg=f"ConfigWarning: Parameter \"prob\" should be set when use operator in \"OpSampler\". The operator \"{op_name}\"'s prob has been set \"0\"."
logger.warning(msg)
prob=param.pop("prob",0)
total_prob+=prob
op=eval(op_name)(**param)
self.ops.update({op:prob})
iftotal_prob>1:
msg=f"ConfigError: The total prob of operators in \"OpSampler\" should be less 1."
logger.error(msg)
raiseException(msg)
# add "None Op" when total_prob < 1, "None Op" do nothing