提交 6e77bd6c 编写于 作者: G gaotingquan 提交者: Wei Shengyu

rm codes for compatibility with old version

上级 f525cea0
...@@ -221,15 +221,6 @@ class DataIterator(object): ...@@ -221,15 +221,6 @@ class DataIterator(object):
def build_dataloader(config, mode): def build_dataloader(config, mode):
if "class_num" in config["Global"]:
global_class_num = config["Global"]["class_num"]
if "class_num" not in config["Arch"]:
config["Arch"]["class_num"] = global_class_num
msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
else:
msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
logger.warning(msg)
class_num = config["Arch"].get("class_num", None) class_num = config["Arch"].get("class_num", None)
config["DataLoader"].update({"class_num": class_num}) config["DataLoader"].update({"class_num": class_num})
config["DataLoader"].update({"epochs": config["Global"]["epochs"]}) config["DataLoader"].update({"epochs": config["Global"]["epochs"]})
......
...@@ -412,28 +412,13 @@ class Engine(object): ...@@ -412,28 +412,13 @@ class Engine(object):
self.config["AMP"]["use_fp16_test"] = True self.config["AMP"]["use_fp16_test"] = True
self.amp_eval = True self.amp_eval = True
# TODO(gaotingquan): to compatible with different versions of Paddle
paddle_version = paddle.__version__[:3] paddle_version = paddle.__version__[:3]
# paddle version < 2.3.0 and not develop # paddle version < 2.3.0 and not develop
if paddle_version not in ["2.3", "0.0"]: if paddle_version not in ["2.3", "2.4", "0.0"]:
if self.mode == "train": msg = "When using AMP, PaddleClas release/2.6 and later version only support PaddlePaddle version >= 2.3.0."
self.model, self.optimizer = paddle.amp.decorate( logger.error(msg)
models=self.model, raise Exception(msg)
optimizers=self.optimizer,
level=self.amp_level,
save_dtype='float32')
elif self.amp_eval:
if self.amp_level == "O2":
msg = "The PaddlePaddle that installed not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. Use FP32 evaluation instead and please notice the Eval Dataset output_fp16 should be 'False'."
logger.warning(msg)
self.amp_eval = False
else:
self.model, self.optimizer = paddle.amp.decorate(
models=self.model,
level=self.amp_level,
save_dtype='float32')
# paddle version >= 2.3.0 or develop
else:
if self.mode == "train" or self.amp_eval: if self.mode == "train" or self.amp_eval:
self.model = paddle.amp.decorate( self.model = paddle.amp.decorate(
models=self.model, models=self.model,
......
...@@ -4,7 +4,7 @@ import paddle ...@@ -4,7 +4,7 @@ import paddle
import paddle.nn as nn import paddle.nn as nn
from ppcls.utils import logger from ppcls.utils import logger
from .celoss import CELoss, MixCELoss from .celoss import CELoss
from .googlenetloss import GoogLeNetLoss from .googlenetloss import GoogLeNetLoss
from .centerloss import CenterLoss from .centerloss import CenterLoss
from .contrasiveloss import ContrastiveLoss from .contrasiveloss import ContrastiveLoss
......
...@@ -66,10 +66,3 @@ class CELoss(nn.Layer): ...@@ -66,10 +66,3 @@ class CELoss(nn.Layer):
soft_label=soft_label, soft_label=soft_label,
reduction=self.reduction) reduction=self.reduction)
return {"CELoss": loss} return {"CELoss": loss}
class MixCELoss(object):
def __init__(self, *args, **kwargs):
msg = "\"MixCELos\" is deprecated, please use \"CELoss\" instead."
logger.error(DeprecationWarning(msg))
raise DeprecationWarning(msg)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册