提交 6e77bd6c 编写于 作者: G gaotingquan 提交者: Wei Shengyu

rm codes for compatibility with old version

上级 f525cea0
......@@ -221,15 +221,6 @@ class DataIterator(object):
def build_dataloader(config, mode):
if "class_num" in config["Global"]:
global_class_num = config["Global"]["class_num"]
if "class_num" not in config["Arch"]:
config["Arch"]["class_num"] = global_class_num
msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
else:
msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
logger.warning(msg)
class_num = config["Arch"].get("class_num", None)
config["DataLoader"].update({"class_num": class_num})
config["DataLoader"].update({"epochs": config["Global"]["epochs"]})
......
......@@ -412,33 +412,18 @@ class Engine(object):
self.config["AMP"]["use_fp16_test"] = True
self.amp_eval = True
# TODO(gaotingquan): to compatible with different versions of Paddle
paddle_version = paddle.__version__[:3]
# paddle version < 2.3.0 and not develop
if paddle_version not in ["2.3", "0.0"]:
if self.mode == "train":
self.model, self.optimizer = paddle.amp.decorate(
models=self.model,
optimizers=self.optimizer,
level=self.amp_level,
save_dtype='float32')
elif self.amp_eval:
if self.amp_level == "O2":
msg = "The PaddlePaddle that installed not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. Use FP32 evaluation instead and please notice the Eval Dataset output_fp16 should be 'False'."
logger.warning(msg)
self.amp_eval = False
else:
self.model, self.optimizer = paddle.amp.decorate(
models=self.model,
level=self.amp_level,
save_dtype='float32')
# paddle version >= 2.3.0 or develop
else:
if self.mode == "train" or self.amp_eval:
self.model = paddle.amp.decorate(
models=self.model,
level=self.amp_level,
save_dtype='float32')
if paddle_version not in ["2.3", "2.4", "0.0"]:
msg = "When using AMP, PaddleClas release/2.6 and later version only support PaddlePaddle version >= 2.3.0."
logger.error(msg)
raise Exception(msg)
if self.mode == "train" or self.amp_eval:
self.model = paddle.amp.decorate(
models=self.model,
level=self.amp_level,
save_dtype='float32')
if self.mode == "train" and len(self.train_loss_func.parameters(
)) > 0:
......
......@@ -4,7 +4,7 @@ import paddle
import paddle.nn as nn
from ppcls.utils import logger
from .celoss import CELoss, MixCELoss
from .celoss import CELoss
from .googlenetloss import GoogLeNetLoss
from .centerloss import CenterLoss
from .contrasiveloss import ContrastiveLoss
......
......@@ -66,10 +66,3 @@ class CELoss(nn.Layer):
soft_label=soft_label,
reduction=self.reduction)
return {"CELoss": loss}
class MixCELoss(object):
def __init__(self, *args, **kwargs):
msg = "\"MixCELos\" is deprecated, please use \"CELoss\" instead."
logger.error(DeprecationWarning(msg))
raise DeprecationWarning(msg)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册