Commit ba2dd01a authored by gaotingquan, committed by Tingquan Gao

refactor: deprecate MixCELoss

Parent: 69d9a477
@@ -22,7 +22,7 @@ Arch:
 # loss function config for traing/eval process
 Loss:
   Train:
-    - MixCELoss:
+    - CELoss:
         weight: 1.0
         epsilon: 0.1
   Eval:
...
@@ -22,7 +22,7 @@ Arch:
 # loss function config for traing/eval process
 Loss:
   Train:
-    - MixCELoss:
+    - CELoss:
         weight: 1.0
   Eval:
     - CELoss:
...
@@ -24,7 +24,7 @@ Arch:
 # loss function config for traing/eval process
 Loss:
   Train:
-    - MixCELoss:
+    - CELoss:
         weight: 1.0
         epsilon: 0.1
   Eval:
...
@@ -22,7 +22,7 @@ Arch:
 # loss function config for traing/eval process
 Loss:
   Train:
-    - MixCELoss:
+    - CELoss:
         weight: 1.0
         epsilon: 0.1
   Eval:
...
@@ -24,7 +24,7 @@ Arch:
 # loss function config for traing/eval process
 Loss:
   Train:
-    - MixCELoss:
+    - CELoss:
         weight: 1.0
         epsilon: 0.1
   Eval:
...
@@ -26,7 +26,7 @@ Arch:
 # loss function config for traing/eval process
 Loss:
   Train:
-    - MixCELoss:
+    - CELoss:
         weight: 1.0
         epsilon: 0.1
   Eval:
...
@@ -22,7 +22,7 @@ Arch:
 # loss function config for traing/eval process
 Loss:
   Train:
-    - MixCELoss:
+    - CELoss:
         weight: 1.0
         epsilon: 0.1
   Eval:
...
@@ -30,7 +30,7 @@ Arch:
 # loss function config for traing/eval process
 Loss:
   Train:
-    - MixCELoss:
+    - CELoss:
         weight: 1.0
         epsilon: 0.1
   Eval:
...
@@ -29,7 +29,7 @@ Arch:
 # loss function config for traing/eval process
 Loss:
   Train:
-    - MixCELoss:
+    - CELoss:
         weight: 1.0
         epsilon: 0.1
   Eval:
...
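Across these configs the only change is the loss name: CELoss keeps the same weight and epsilon keys, applies label smoothing when epsilon is set, and also accepts the soft targets that the mixup/cutmix batch operators now emit, which is what makes the dedicated MixCELoss unnecessary. Below is a minimal NumPy sketch of the label-smoothed, soft-target cross entropy that this config block selects; the function names are illustrative, not the PaddleClas implementation.

import numpy as np

def log_softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def ce_loss(logits, target, epsilon=None):
    """Cross entropy over hard ids or soft targets, with optional label smoothing."""
    num_classes = logits.shape[-1]
    if target.ndim == 1:                       # hard labels -> one-hot rows
        target = np.eye(num_classes, dtype="float32")[target]
    if epsilon is not None:                    # label smoothing, as in `epsilon: 0.1`
        target = target * (1 - epsilon) + epsilon / num_classes
    return float(np.mean(-(target * log_softmax(logits)).sum(axis=-1)))

logits = np.random.randn(4, 10).astype("float32")
labels = np.array([1, 3, 5, 7])
print(ce_loss(logits, labels, epsilon=0.1))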
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import copy
 import paddle
 import numpy as np
@@ -36,7 +37,7 @@ from ppcls.data import preprocess
 from ppcls.data.preprocess import transform
 
 
-def create_operators(params):
+def create_operators(params, class_num=None):
     """
     create operators based on the config
@@ -50,7 +51,10 @@ def create_operators(params):
                           dict) and len(operator) == 1, "yaml format error"
         op_name = list(operator)[0]
        param = {} if operator[op_name] is None else operator[op_name]
-        op = getattr(preprocess, op_name)(**param)
+        op_func = getattr(preprocess, op_name)
+        if "class_num" in inspect.getfullargspec(op_func).args:
+            param.update({"class_num": class_num})
+        op = op_func(**param)
         ops.append(op)
     return ops
@@ -65,6 +69,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
         from ppcls.data.dataloader.dali import dali_dataloader
         return dali_dataloader(config, mode, paddle.device.get_device(), seed)
 
+    class_num = config.get("class_num", None)
     config_dataset = config[mode]['dataset']
     config_dataset = copy.deepcopy(config_dataset)
     dataset_name = config_dataset.pop('name')
@@ -104,7 +109,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
         return [np.stack(slot, axis=0) for slot in slots]
 
     if isinstance(batch_transform, list):
-        batch_ops = create_operators(batch_transform)
+        batch_ops = create_operators(batch_transform, class_num)
         batch_collate_fn = mix_collate_fn
     else:
         batch_collate_fn = None
...
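The refactored create_operators forwards class_num only to operators whose constructor actually declares it, which it discovers with inspect.getfullargspec; build_dataloader reads class_num from the DataLoader config and passes it along when building batch transforms. Here is a self-contained sketch of that dispatch pattern, with toy operator classes and a lookup dict standing in for ppcls.data.preprocess (both are illustrative assumptions, not the real module).

import inspect

class MixupOperator:                       # toy stand-in: needs class_num
    def __init__(self, class_num, alpha=1.0):
        self.class_num, self.alpha = class_num, alpha

class RandFlipImage:                       # toy stand-in: does not need class_num
    def __init__(self, flip_code=1):
        self.flip_code = flip_code

_REGISTRY = {"MixupOperator": MixupOperator, "RandFlipImage": RandFlipImage}

def create_operators(params, class_num=None):
    ops = []
    for operator in params:
        op_name = list(operator)[0]
        param = {} if operator[op_name] is None else dict(operator[op_name])
        op_func = _REGISTRY[op_name]       # the real code uses getattr(preprocess, op_name)
        # inject class_num only when the operator's __init__ declares it
        if "class_num" in inspect.getfullargspec(op_func).args:
            param.update({"class_num": class_num})
        ops.append(op_func(**param))
    return ops

ops = create_operators([{"MixupOperator": {"alpha": 0.2}}, {"RandFlipImage": None}],
                       class_num=10)
print(type(ops[0]).__name__, ops[0].class_num)   # MixupOperator 10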
@@ -44,6 +44,14 @@ class BatchOperator(object):
             labels.append(item[1])
         return np.array(imgs), np.array(labels), bs
 
+    def _one_hot(self, targets):
+        return np.eye(self.class_num, dtype="float32")[targets]
+
+    def _mix_target(self, targets0, targets1, lam):
+        one_hots0 = self._one_hot(targets0)
+        one_hots1 = self._one_hot(targets1)
+        return one_hots0 * lam + one_hots1 * (1 - lam)
+
     def __call__(self, batch):
         return batch
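The new _one_hot and _mix_target helpers turn integer labels into one-hot rows and blend two label sets by the mixing coefficient, so each mixed sample carries a single soft target. The same arithmetic in standalone NumPy, as a sketch that mirrors (rather than imports) the helpers above:

import numpy as np

def one_hot(targets, class_num):
    return np.eye(class_num, dtype="float32")[targets]

def mix_target(targets0, targets1, lam, class_num):
    # lam * one_hot(y_a) + (1 - lam) * one_hot(y_b)
    return one_hot(targets0, class_num) * lam + one_hot(targets1, class_num) * (1 - lam)

y_a = np.array([0, 2])
y_b = np.array([1, 2])
print(mix_target(y_a, y_b, lam=0.7, class_num=3))
# [[0.7 0.3 0. ]
#  [0.  0.  1. ]]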
@@ -51,7 +59,7 @@
 class MixupOperator(BatchOperator):
     """ Mixup operator """
 
-    def __init__(self, alpha: float=1.):
+    def __init__(self, class_num, alpha: float=1.):
         """Build Mixup operator
 
         Args:
@@ -64,21 +72,27 @@ class MixupOperator(BatchOperator):
             raise Exception(
                 f"Parameter \"alpha\" of Mixup should be greater than 0. \"alpha\": {alpha}."
             )
+        if not class_num:
+            msg = "Please set \"Arch.class_num\" in config if use \"MixupOperator\"."
+            logger.error(Exception(msg))
+            raise Exception(msg)
+
         self._alpha = alpha
+        self.class_num = class_num
 
     def __call__(self, batch):
         imgs, labels, bs = self._unpack(batch)
         idx = np.random.permutation(bs)
         lam = np.random.beta(self._alpha, self._alpha)
-        lams = np.array([lam] * bs, dtype=np.float32)
         imgs = lam * imgs + (1 - lam) * imgs[idx]
-        return list(zip(imgs, labels, labels[idx], lams))
+        targets = self._mix_target(labels, labels[idx], lam)
+        return list(zip(imgs, targets))
 
 
 class CutmixOperator(BatchOperator):
     """ Cutmix operator """
 
-    def __init__(self, alpha=0.2):
+    def __init__(self, class_num, alpha=0.2):
         """Build Cutmix operator
 
         Args:
@@ -91,7 +105,13 @@ class CutmixOperator(BatchOperator):
             raise Exception(
                 f"Parameter \"alpha\" of Cutmix should be greater than 0. \"alpha\": {alpha}."
             )
+        if not class_num:
+            msg = "Please set \"Arch.class_num\" in config if use \"CutmixOperator\"."
+            logger.error(Exception(msg))
+            raise Exception(msg)
+
         self._alpha = alpha
+        self.class_num = class_num
 
     def _rand_bbox(self, size, lam):
         """ _rand_bbox """
@@ -121,18 +141,29 @@ class CutmixOperator(BatchOperator):
         imgs[:, :, bbx1:bbx2, bby1:bby2] = imgs[idx, :, bbx1:bbx2, bby1:bby2]
         lam = 1 - (float(bbx2 - bbx1) * (bby2 - bby1) /
                    (imgs.shape[-2] * imgs.shape[-1]))
-        lams = np.array([lam] * bs, dtype=np.float32)
-        return list(zip(imgs, labels, labels[idx], lams))
+        targets = self._mix_target(labels, labels[idx], lam)
+        return list(zip(imgs, targets))
 
 
 class FmixOperator(BatchOperator):
     """ Fmix operator """
 
-    def __init__(self, alpha=1, decay_power=3, max_soft=0., reformulate=False):
+    def __init__(self,
+                 class_num,
+                 alpha=1,
+                 decay_power=3,
+                 max_soft=0.,
+                 reformulate=False):
+        if not class_num:
+            msg = "Please set \"Arch.class_num\" in config if use \"FmixOperator\"."
+            logger.error(Exception(msg))
+            raise Exception(msg)
+
         self._alpha = alpha
         self._decay_power = decay_power
         self._max_soft = max_soft
         self._reformulate = reformulate
+        self.class_num = class_num
 
     def __call__(self, batch):
         imgs, labels, bs = self._unpack(batch)
@@ -141,20 +172,27 @@ class FmixOperator(BatchOperator):
         lam, mask = sample_mask(self._alpha, self._decay_power, \
             size, self._max_soft, self._reformulate)
         imgs = mask * imgs + (1 - mask) * imgs[idx]
-        return list(zip(imgs, labels, labels[idx], [lam] * bs))
+        targets = self._mix_target(labels, labels[idx], lam)
+        return list(zip(imgs, targets))
 
 
 class OpSampler(object):
     """ Sample a operator from  """
 
-    def __init__(self, **op_dict):
+    def __init__(self, class_num, **op_dict):
         """Build OpSampler
 
         Raises:
             Exception: The parameter \"prob\" of operator(s) are be set error.
         """
+        if not class_num:
+            msg = "Please set \"Arch.class_num\" in config if use \"OpSampler\"."
+            logger.error(Exception(msg))
+            raise Exception(msg)
+
         if len(op_dict) < 1:
             msg = f"ConfigWarning: No operator in \"OpSampler\". \"OpSampler\" has been skipped."
+            logger.warning(msg)
 
         self.ops = {}
         total_prob = 0
@@ -165,12 +203,13 @@ class OpSampler(object):
                 logger.warning(msg)
             prob = param.pop("prob", 0)
             total_prob += prob
+            param.update({"class_num": class_num})
             op = eval(op_name)(**param)
             self.ops.update({op: prob})
 
         if total_prob > 1:
             msg = f"ConfigError: The total prob of operators in \"OpSampler\" should be less 1."
-            logger.error(msg)
+            logger.error(Exception(msg))
             raise Exception(msg)
 
         # add "None Op" when total_prob < 1, "None Op" do nothing
...
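Taken together, every batch operator now returns (image, soft_target) pairs instead of (image, label_a, label_b, lam) tuples, so downstream code no longer needs to know whether mixing happened. A condensed, illustrative stand-in for the refactored MixupOperator that shows this contract (not the PaddleClas class itself):

import numpy as np

class MixupSketch:
    """Minimal mixup batch operator sketch: emits pre-mixed soft targets."""

    def __init__(self, class_num, alpha=1.0):
        self.class_num, self.alpha = class_num, alpha

    def __call__(self, batch):
        imgs = np.stack([img for img, _ in batch]).astype("float32")
        labels = np.array([label for _, label in batch])
        idx = np.random.permutation(len(batch))
        lam = np.random.beta(self.alpha, self.alpha)
        imgs = lam * imgs + (1 - lam) * imgs[idx]
        one_hot = np.eye(self.class_num, dtype="float32")[labels]
        targets = lam * one_hot + (1 - lam) * one_hot[idx]
        # each sample carries one soft target instead of (y_a, y_b, lam)
        return list(zip(imgs, targets))

batch = [(np.random.rand(3, 8, 8), i % 5) for i in range(4)]
for img, target in MixupSketch(class_num=5, alpha=0.2)(batch):
    print(img.shape, target)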
@@ -112,6 +112,8 @@ class Engine(object):
             }
             paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
 
+        class_num = config["Arch"].get("class_num", None)
+        self.config["DataLoader"].update({"class_num": class_num})
         # build dataloader
         if self.mode == 'train':
             self.train_dataloader = build_dataloader(
...
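The engine change is small: Arch.class_num is copied into the DataLoader config so that build_dataloader can hand it to create_operators. A toy illustration with a trimmed-down config dict (only the relevant keys, not the full PaddleClas config):

config = {
    "Arch": {"name": "ResNet50", "class_num": 1000},
    "DataLoader": {"Train": {"dataset": {"name": "ImageNetDataset"}}},
}
class_num = config["Arch"].get("class_num", None)
config["DataLoader"].update({"class_num": class_num})
print(config["DataLoader"]["class_num"])  # 1000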
@@ -36,25 +36,19 @@ def train_epoch(engine, epoch_id, print_batch_step):
         ]
         batch_size = batch[0].shape[0]
         if not engine.config["Global"].get("use_multilabel", False):
-            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
+            batch[1] = batch[1].reshape([batch_size, -1])
         engine.global_step += 1
-        if engine.config["DataLoader"]["Train"]["dataset"].get(
-                "batch_transform_ops", None):
-            gt_input = batch[1:]
-        else:
-            gt_input = batch[1]
         # image input
         if engine.amp:
             with paddle.amp.auto_cast(custom_black_list={
                     "flatten_contiguous_range", "greater_than"
             }):
                 out = forward(engine, batch)
-                loss_dict = engine.train_loss_func(out, gt_input)
         else:
             out = forward(engine, batch)
-            loss_dict = engine.train_loss_func(out, gt_input)
+        loss_dict = engine.train_loss_func(out, batch[1])
 
         # step opt and lr
         if engine.amp:
...
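With soft targets prepared in the dataloader, the training step no longer branches on batch_transform_ops: labels are reshaped to [batch_size, -1] and passed straight to the loss, whether they are hard class ids or mixed one-hot rows. A small sketch of that label handling (illustrative, not the engine code):

import numpy as np

def prepare_targets(labels, use_multilabel=False):
    """Mirror of the label handling in the refactored train_epoch (sketch only)."""
    batch_size = labels.shape[0]
    if not use_multilabel:
        # hard ids become shape [bs, 1]; mixed soft targets keep shape [bs, class_num]
        labels = labels.reshape([batch_size, -1])
    return labels

print(prepare_targets(np.array([3, 1, 4])).shape)       # (3, 1)
print(prepare_targets(np.random.rand(3, 10)).shape)     # (3, 10)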
@@ -12,10 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 
+from ppcls.utils import logger
+
 
 class CELoss(nn.Layer):
     """
@@ -56,19 +60,8 @@ class CELoss(nn.Layer):
         return {"CELoss": loss}
 
 
-class MixCELoss(CELoss):
-    """
-    Cross entropy loss with mix(mixup, cutmix, fixmix)
-    """
-
-    def __init__(self, epsilon=None):
-        super().__init__()
-        self.epsilon = epsilon
-
-    def __call__(self, input, batch):
-        target0, target1, lam = batch
-        loss0 = super().forward(input, target0)["CELoss"]
-        loss1 = super().forward(input, target1)["CELoss"]
-        loss = lam * loss0 + (1.0 - lam) * loss1
-        loss = paddle.mean(loss)
-        return {"MixCELoss": loss}
+class MixCELoss(object):
+    def __init__(self, *args, **kwargs):
+        msg = "\"MixCELos\" is deprecated, please use \"CELoss\" instead."
+        logger.error(DeprecationWarning(msg))
+        raise DeprecationWarning(msg)