提交 5bf74beb 编写于 作者: D dongshuilong

update according comments

上级 15f6f581
......@@ -13,21 +13,13 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
import time
import os
import platform
import datetime
import argparse
import paddle
import paddle.nn as nn
import paddle.distributed as dist
from visualdl import LogWriter
from paddle import nn
from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter
......@@ -67,7 +59,8 @@ class Engine(object):
print_config(config)
# init train_func and eval_func
assert self.eval_mode in ["classification", "retrieval"], logger.error("Invalid eval mode: {}".format(self.eval_mode))
assert self.eval_mode in ["classification", "retrieval"], logger.error(
"Invalid eval mode: {}".format(self.eval_mode))
self.train_epoch_func = train_epoch
self.eval_func = getattr(evaluation, self.eval_mode + "_eval")
......
......@@ -14,14 +14,10 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import time
import platform
import paddle
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../../')))
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
......@@ -61,14 +57,14 @@ def classification_eval(evaler, epoch_id=0):
out = evaler.model(batch[0])
# calc loss
if evaler.eval_loss_func is not None:
loss_dict = evaler.eval_loss_func(out, batch[-1])
loss_dict = evaler.eval_loss_func(out, batch[1])
for key in loss_dict:
if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0], batch_size)
# calc metric
if evaler.eval_metric_func is not None:
metric_dict = evaler.eval_metric_func(out, batch[-1])
metric_dict = evaler.eval_metric_func(out, batch[1])
if paddle.distributed.get_world_size() > 1:
for key in metric_dict:
paddle.distributed.all_reduce(
......
......@@ -14,15 +14,9 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import time
import platform
import paddle
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../../')))
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
......
......@@ -13,19 +13,8 @@
# limitations under the License.
from __future__ import absolute_import, division, print_function
import datetime
import os
import platform
import sys
import time
import numpy as np
import paddle
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../../')))
from ppcls.utils import logger
from ppcls.utils.misc import AverageMeter
from ppcls.engine.train.utils import update_loss, update_metric, log_info
......@@ -88,6 +77,7 @@ def train_epoch(trainer, epoch_id, print_batch_step):
log_info(trainer, batch_size, epoch_id, iter_id)
tic = time.time()
def forward(trainer, batch):
if trainer.eval_mode == "classification":
return trainer.model(batch[0])
......
......@@ -14,16 +14,6 @@
from __future__ import absolute_import, division, print_function
import datetime
import os
import platform
import sys
import time
import numpy as np
import paddle
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../../')))
from ppcls.utils import logger
from ppcls.utils.misc import AverageMeter
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册