diff --git a/ppcls/configs/ImageNet/Distillation/mv3_large_x1_0_distill_mv3_small_x1_0.yaml b/ppcls/configs/ImageNet/Distillation/mv3_large_x1_0_distill_mv3_small_x1_0.yaml index a7265b066e1c526fbb63f59993ff68bb4ae09d8a..b230f11cbde78e195355d00a7b042b0d9e6a4026 100644 --- a/ppcls/configs/ImageNet/Distillation/mv3_large_x1_0_distill_mv3_small_x1_0.yaml +++ b/ppcls/configs/ImageNet/Distillation/mv3_large_x1_0_distill_mv3_small_x1_0.yaml @@ -49,9 +49,8 @@ Loss: model_name_pairs: - ["Student", "Teacher"] Eval: - - DistillationGTCELoss: + - CELoss: weight: 1.0 - model_names: ["Student"] Optimizer: diff --git a/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_afd.yaml b/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_afd.yaml index e5b8b716222316c0fca80a69154b0c937e6c52da..000cb9add132c0231d72d47e2947d4397c380d61 100644 --- a/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_afd.yaml +++ b/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_afd.yaml @@ -88,10 +88,8 @@ Loss: s_shapes: *s_shapes t_shapes: *t_shapes Eval: - - DistillationGTCELoss: + - CELoss: weight: 1.0 - model_names: ["Student"] - Optimizer: name: Momentum diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py index 6e7fc1a76fe8c3bc4402d9428d372b9c2b50a17b..f4c90a393f5043575c5e49f16fd5b220c881e0fc 100644 --- a/ppcls/engine/evaluation/classification.py +++ b/ppcls/engine/evaluation/classification.py @@ -80,22 +80,17 @@ def classification_eval(engine, epoch_id=0): current_samples = batch_size * paddle.distributed.get_world_size() accum_samples += current_samples + if isinstance(out, dict) and "Student" in out: + out = out["Student"] + if isinstance(out, dict) and "logits" in out: + out = out["logits"] + # gather Tensor when distributed if paddle.distributed.get_world_size() > 1: label_list = [] paddle.distributed.all_gather(label_list, batch[1]) labels = paddle.concat(label_list, 0) - if isinstance(out, dict): - if "Student" in out: - out = out["Student"] - if isinstance(out, dict): - out = out["logits"] - elif "logits" in out: - out = out["logits"] - else: - msg = "Error: Wrong key in out!" - raise Exception(msg) if isinstance(out, list): preds = [] for x in out: diff --git a/ppcls/optimizer/__init__.py b/ppcls/optimizer/__init__.py index d27f1100eef871db48b8da9ab86eba6af8aecee8..44d7b5ac0b33f267f6893d39bd42d27c8bac0573 100644 --- a/ppcls/optimizer/__init__.py +++ b/ppcls/optimizer/__init__.py @@ -118,8 +118,6 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None): if hasattr(model_list[i], optim_scope): optim_model.append(getattr(model_list[i], optim_scope)) - assert len(optim_model) == 1, \ - "Invalid optim model for optim scope({}), number of optim_model={}".format(optim_scope, len(optim_model)) optim = getattr(optimizer, optim_name)( learning_rate=lr, grad_clip=grad_clip, **optim_cfg)(model_list=optim_model)