提交 5bf74beb 编写于 作者: D dongshuilong

update according comments

上级 15f6f581
...@@ -13,21 +13,13 @@ ...@@ -13,21 +13,13 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import sys
import numpy as np
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
import time import os
import platform import platform
import datetime
import argparse
import paddle import paddle
import paddle.nn as nn
import paddle.distributed as dist import paddle.distributed as dist
from visualdl import LogWriter from visualdl import LogWriter
from paddle import nn
from ppcls.utils.check import check_gpu from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter from ppcls.utils.misc import AverageMeter
...@@ -67,7 +59,8 @@ class Engine(object): ...@@ -67,7 +59,8 @@ class Engine(object):
print_config(config) print_config(config)
# init train_func and eval_func # init train_func and eval_func
assert self.eval_mode in ["classification", "retrieval"], logger.error("Invalid eval mode: {}".format(self.eval_mode)) assert self.eval_mode in ["classification", "retrieval"], logger.error(
"Invalid eval mode: {}".format(self.eval_mode))
self.train_epoch_func = train_epoch self.train_epoch_func = train_epoch
self.eval_func = getattr(evaluation, self.eval_mode + "_eval") self.eval_func = getattr(evaluation, self.eval_mode + "_eval")
......
...@@ -14,14 +14,10 @@ ...@@ -14,14 +14,10 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import sys
import time import time
import platform import platform
import paddle import paddle
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../../')))
from ppcls.utils.misc import AverageMeter from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger from ppcls.utils import logger
...@@ -61,14 +57,14 @@ def classification_eval(evaler, epoch_id=0): ...@@ -61,14 +57,14 @@ def classification_eval(evaler, epoch_id=0):
out = evaler.model(batch[0]) out = evaler.model(batch[0])
# calc loss # calc loss
if evaler.eval_loss_func is not None: if evaler.eval_loss_func is not None:
loss_dict = evaler.eval_loss_func(out, batch[-1]) loss_dict = evaler.eval_loss_func(out, batch[1])
for key in loss_dict: for key in loss_dict:
if key not in output_info: if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f') output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0], batch_size) output_info[key].update(loss_dict[key].numpy()[0], batch_size)
# calc metric # calc metric
if evaler.eval_metric_func is not None: if evaler.eval_metric_func is not None:
metric_dict = evaler.eval_metric_func(out, batch[-1]) metric_dict = evaler.eval_metric_func(out, batch[1])
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
for key in metric_dict: for key in metric_dict:
paddle.distributed.all_reduce( paddle.distributed.all_reduce(
......
...@@ -14,15 +14,9 @@ ...@@ -14,15 +14,9 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import sys
import time
import platform import platform
import paddle import paddle
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../../')))
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger from ppcls.utils import logger
......
...@@ -13,19 +13,8 @@ ...@@ -13,19 +13,8 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import datetime
import os
import platform
import sys
import time import time
import numpy as np
import paddle import paddle
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../../')))
from ppcls.utils import logger
from ppcls.utils.misc import AverageMeter
from ppcls.engine.train.utils import update_loss, update_metric, log_info from ppcls.engine.train.utils import update_loss, update_metric, log_info
...@@ -88,6 +77,7 @@ def train_epoch(trainer, epoch_id, print_batch_step): ...@@ -88,6 +77,7 @@ def train_epoch(trainer, epoch_id, print_batch_step):
log_info(trainer, batch_size, epoch_id, iter_id) log_info(trainer, batch_size, epoch_id, iter_id)
tic = time.time() tic = time.time()
def forward(trainer, batch): def forward(trainer, batch):
if trainer.eval_mode == "classification": if trainer.eval_mode == "classification":
return trainer.model(batch[0]) return trainer.model(batch[0])
......
...@@ -14,16 +14,6 @@ ...@@ -14,16 +14,6 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import datetime import datetime
import os
import platform
import sys
import time
import numpy as np
import paddle
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../../')))
from ppcls.utils import logger from ppcls.utils import logger
from ppcls.utils.misc import AverageMeter from ppcls.utils.misc import AverageMeter
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册