未验证 提交 38813434 编写于 作者: L littletomatodonkey 提交者: GitHub

add support for train and eval (#752)

* add support for train and eval

* rm unused code

* add support for metric save and load ckp
上级 decdb51b
# global configs
Global:
pretrained_model: ""
checkpoints: null
pretrained_model: null
output_dir: "./output/"
device: "gpu"
class_num: 1000
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 90
epochs: 120
print_batch_step: 10
use_visualdl: False
image_shape: [3, 224, 224]
......
......@@ -25,9 +25,7 @@ import paddle
import paddle.nn as nn
import paddle.distributed as dist
from ppcls.utils import config
from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.data import build_dataloader
......@@ -35,16 +33,15 @@ from ppcls.arch import build_model
from ppcls.arch.loss_metrics import build_loss
from ppcls.arch.loss_metrics import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.save_load import load_dygraph_pretrain
from ppcls.utils.save_load import init_model
from ppcls.utils import save_load
class Trainer(object):
def __init__(self, mode="train"):
args = config.parse_args()
self.config = config.get_config(
args.config, overrides=args.override, show=True)
def __init__(self, config, mode="train"):
self.mode = mode
self.config = config
self.output_dir = self.config['Global']['output_dir']
# set device
assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu"]
......@@ -56,6 +53,10 @@ class Trainer(object):
dist.init_parallel_env()
self.model = build_model(self.config["Arch"])
if self.config["Global"]["pretrained_model"] is not None:
load_dygraph_pretrain(self.model,
self.config["Global"]["pretrained_model"])
if self.config["Global"]["distributed"]:
self.model = paddle.DataParallel(self.model)
......@@ -122,7 +123,15 @@ class Trainer(object):
# global iter counter
global_step = 0
for epoch_id in range(1, self.config["Global"]["epochs"] + 1):
if self.config["Global"]["checkpoints"] is not None:
metric_info = init_model(self.config["Global"], self.model,
optimizer)
if metric_info is not None:
best_metric.update(metric_info)
for epoch_id in range(best_metric["epoch"] + 1,
self.config["Global"]["epochs"] + 1):
acc = 0.0
self.model.train()
for iter_id, batch in enumerate(train_dataloader()):
batch_size = batch[0].shape[0]
......@@ -176,12 +185,13 @@ class Trainer(object):
"eval_during_train"] and epoch_id % self.config["Global"][
"eval_during_train"] == 0:
acc = self.eval(epoch_id)
if acc >= best_metric["metric"]:
if acc > best_metric["metric"]:
best_metric["metric"] = acc
best_metric["epoch"] = epoch_id
save_load.save_model(
self.model,
optimizer,
best_metric,
self.output_dir,
model_name=self.config["Arch"]["name"],
prefix="best_model")
......@@ -190,7 +200,8 @@ class Trainer(object):
if epoch_id % save_interval == 0:
save_load.save_model(
self.model,
optimizer,
optimizer, {"metric": acc,
"epoch": epoch_id},
self.output_dir,
model_name=self.config["Arch"]["name"],
prefix="ppcls_epoch_{}".format(epoch_id))
......@@ -266,12 +277,3 @@ class Trainer(object):
return -1
# return 1st metric in the dict
return output_info[metric_key].avg
def main():
trainer = Trainer()
trainer.train()
if __name__ == "__main__":
main()
......@@ -71,11 +71,17 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False):
return
def load_dygraph_pretrain_from_url(model, pretrained_url, use_ssld, load_static_weights=False):
def load_dygraph_pretrain_from_url(model,
pretrained_url,
use_ssld,
load_static_weights=False):
if use_ssld:
pretrained_url = pretrained_url.replace("_pretrained", "_ssld_pretrained")
local_weight_path = get_weights_path_from_url(pretrained_url).replace(".pdparams", "")
load_dygraph_pretrain(model, path=local_weight_path, load_static_weights=load_static_weights)
pretrained_url = pretrained_url.replace("_pretrained",
"_ssld_pretrained")
local_weight_path = get_weights_path_from_url(pretrained_url).replace(
".pdparams", "")
load_dygraph_pretrain(
model, path=local_weight_path, load_static_weights=load_static_weights)
return
......@@ -121,10 +127,11 @@ def init_model(config, net, optimizer=None):
"Given dir {}.pdopt not exist.".format(checkpoints)
para_dict = paddle.load(checkpoints + ".pdparams")
opti_dict = paddle.load(checkpoints + ".pdopt")
metric_dict = paddle.load(checkpoints + ".pdstates")
net.set_dict(para_dict)
optimizer.set_state_dict(opti_dict)
logger.info("Finish load checkpoints from {}".format(checkpoints))
return
return metric_dict
pretrained_model = config.get('pretrained_model')
load_static_weights = config.get('load_static_weights', False)
......@@ -155,7 +162,12 @@ def _save_student_model(net, model_prefix):
student_model_prefix))
def save_model(net, optimizer, model_path, model_name="", prefix='ppcls'):
def save_model(net,
optimizer,
metric_info,
model_path,
model_name="",
prefix='ppcls'):
"""
save model to the target path
"""
......@@ -169,4 +181,5 @@ def save_model(net, optimizer, model_path, model_name="", prefix='ppcls'):
paddle.save(net.state_dict(), model_prefix + ".pdparams")
paddle.save(optimizer.state_dict(), model_prefix + ".pdopt")
paddle.save(metric_info, model_prefix + ".pdstates")
logger.info("Already save model in {}".format(model_path))
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -12,105 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn.functional as F
import argparse
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
from ppcls.utils import logger
from ppcls.utils.save_load import init_model
from ppcls.utils.config import get_config
from ppcls.utils import multi_hot_encode
from ppcls.utils import accuracy_score
from ppcls.utils import mean_average_precision
from ppcls.utils import precision_recall_fscore
from ppcls.data import Reader
import program
import numpy as np
def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace: ``config`` is the config file path
        (``./configs/eval.yaml`` when ``-c`` is not given) and ``override``
        is the list of values collected from repeated ``-o`` flags.
    """
    cli = argparse.ArgumentParser("PaddleClas eval script")
    # Each entry is (flags, keyword arguments) for one add_argument call.
    option_specs = (
        (('-c', '--config'), {
            'type': str,
            'default': './configs/eval.yaml',
            'help': 'config file path',
        }),
        (('-o', '--override'), {
            'action': 'append',
            'default': [],
            'help': 'config options to be overridden',
        }),
    )
    for flags, spec in option_specs:
        cli.add_argument(*flags, **spec)
    return cli.parse_args()
def main(args, return_dict=None):
    """Evaluate a model on the validation set described by the config.

    Args:
        args: Parsed command-line arguments; ``args.config`` and
            ``args.override`` are forwarded to ``get_config``.
        return_dict: Optional dict that receives the result under the key
            ``"top1_acc"`` (single-label) or ``"mean average precision"``
            (multi-label). A fresh dict is created when omitted, so results
            never leak between calls through a shared default.

    Returns:
        The value returned by ``program.run`` in single-label mode, or the
        value returned by ``mean_average_precision`` in multi-label mode.
    """
    if return_dict is None:
        return_dict = {}

    config = get_config(args.config, overrides=args.override, show=True)
    config.mode = "valid"
    # assign place
    use_gpu = config.get("use_gpu", True)
    place = paddle.set_device('gpu' if use_gpu else 'cpu')
    multilabel = config.get("multilabel", False)

    trainer_num = paddle.distributed.get_world_size()
    use_data_parallel = trainer_num != 1
    config["use_data_parallel"] = use_data_parallel

    if config["use_data_parallel"]:
        paddle.distributed.init_parallel_env()

    net = program.create_model(config.ARCHITECTURE, config.classes_num)
    init_model(config, net, optimizer=None)
    valid_dataloader = Reader(config, 'valid', places=place)()
    if len(valid_dataloader) <= 0:
        logger.error(
            "valid dataloader is empty, please check your data config again!")
        sys.exit(-1)

    net.eval()
    with paddle.no_grad():
        if not multilabel:
            top1_acc = program.run(valid_dataloader, config, net, None, None,
                                   0, 'valid')
            return_dict["top1_acc"] = top1_acc
            return top1_acc

        # Multi-label path: gather per-sample scores and labels, then score
        # them in one mean_average_precision call.
        # The config lookup does not depend on the batch, so read it once.
        use_distillation = config.get("use_distillation", False)
        all_outs = []
        targets = []
        for batch in valid_dataloader():
            feeds = program.create_feeds(batch, False, config.classes_num,
                                         multilabel)
            out = net(feeds["image"])
            # NOTE(review): sigmoid is applied before out[1] is selected for
            # distillation; order kept as in the original -- confirm that the
            # network output supports this.
            out = F.sigmoid(out)
            if use_distillation:
                out = out[1]
            all_outs.extend(list(out.numpy()))
            targets.extend(list(feeds["label"].numpy()))

        all_outs = np.array(all_outs)
        targets = np.array(targets)
        mAP = mean_average_precision(all_outs, targets)
        return_dict["mean average precision"] = mAP
        return mAP
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
from ppcls.utils import config
from ppcls.engine.trainer import Trainer
if __name__ == '__main__':
args = parse_args()
return_dict = {}
main(args, return_dict)
print(return_dict)
if __name__ == "__main__":
    # Script entry point: parse the CLI, load the config, run evaluation.
    cli_args = config.parse_args()
    eval_config = config.get_config(
        cli_args.config, overrides=cli_args.override, show=True)
    Trainer(eval_config, mode="eval").eval()
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -15,144 +15,16 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
import paddle
from ppcls.data import Reader
from ppcls.utils.config import get_config
from ppcls.utils.save_load import init_model, save_model
from ppcls.utils import logger
import program
def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace: ``config`` is the config file path
        (``configs/ResNet/ResNet50.yaml`` when ``-c`` is not given),
        ``profiler_options`` is the raw profiler option string or ``None``,
        and ``override`` is the list of values collected from repeated
        ``-o`` flags.
    """
    cli = argparse.ArgumentParser("PaddleClas train script")
    # Each entry is (flags, keyword arguments) for one add_argument call.
    option_specs = (
        (('-c', '--config'), {
            'type': str,
            'default': 'configs/ResNet/ResNet50.yaml',
            'help': 'config file path',
        }),
        (('-p', '--profiler_options'), {
            'type': str,
            'default': None,
            'help': 'The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".',
        }),
        (('-o', '--override'), {
            'action': 'append',
            'default': [],
            'help': 'config options to be overridden',
        }),
    )
    for flags, spec in option_specs:
        cli.add_argument(*flags, **spec)
    return cli.parse_args()
def main(args):
    """Train a model according to the config named on the command line.

    Loads the config, selects the device, builds the model and optimizer,
    restores weights through ``init_model``, then runs the epoch loop:
    train, optionally validate and keep the best model, and save a
    checkpoint every ``config.save_interval`` epochs.

    Args:
        args: Parsed command-line arguments; ``args.config`` and
            ``args.override`` are forwarded to ``get_config`` and
            ``args.profiler_options`` to the training ``program.run`` call.
    """
    paddle.seed(12345)
    config = get_config(args.config, overrides=args.override, show=True)

    # assign the place
    use_gpu = config.get("use_gpu", True)
    use_xpu = config.get("use_xpu", False)
    assert (
        use_gpu and use_xpu
    ) is not True, "gpu and xpu can not be true in the same time in static mode!"
    if use_gpu:
        device_name = 'gpu'
    elif use_xpu:
        device_name = 'xpu'
    else:
        device_name = 'cpu'
    place = paddle.set_device(device_name)

    # More than one trainer process means data-parallel training.
    config["use_data_parallel"] = paddle.distributed.get_world_size() != 1
    if config["use_data_parallel"]:
        paddle.distributed.init_parallel_env()

    net = program.create_model(config.ARCHITECTURE, config.classes_num)
    optimizer, lr_scheduler = program.create_optimizer(
        config, parameter_list=net.parameters())

    # The wrapped network is used for training; the plain one for
    # validation and saving.
    dp_net = net
    if config["use_data_parallel"]:
        dp_net = paddle.DataParallel(
            net,
            find_unused_parameters=config.get("find_unused_parameters",
                                              False))

    # load model from checkpoint or pretrained model
    init_model(config, net, optimizer)

    train_dataloader = Reader(config, 'train', places=place)()
    if len(train_dataloader) <= 0:
        logger.error(
            "train dataloader is empty, please check your data config again!")
        sys.exit(-1)

    if config.validate:
        valid_dataloader = Reader(config, 'valid', places=place)()
        if len(valid_dataloader) <= 0:
            logger.error(
                "valid dataloader is empty, please check your data config again!"
            )
            sys.exit(-1)

    last_epoch_id = config.get("last_epoch", -1)
    best_top1_acc = 0.0  # best top1 acc record
    best_top1_epoch = last_epoch_id

    vdl_writer = None
    vdl_writer_path = config.get("vdl_dir", None)
    if vdl_writer_path:
        from visualdl import LogWriter
        vdl_writer = LogWriter(vdl_writer_path)

    # Ensure that the vdl log file can be closed normally
    try:
        for epoch_id in range(last_epoch_id + 1, config.epochs):
            net.train()
            # 1. train with train dataset
            program.run(train_dataloader, config, dp_net, optimizer,
                        lr_scheduler, epoch_id, 'train', vdl_writer,
                        args.profiler_options)

            # 2. validate with validate dataset
            if config.validate and epoch_id % config.valid_interval == 0:
                net.eval()
                with paddle.no_grad():
                    top1_acc = program.run(valid_dataloader, config, net,
                                           None, None, epoch_id, 'valid',
                                           vdl_writer)
                if top1_acc > best_top1_acc:
                    best_top1_acc = top1_acc
                    best_top1_epoch = epoch_id
                    model_path = os.path.join(config.model_save_dir,
                                              config.ARCHITECTURE["name"])
                    save_model(net, optimizer, model_path, "best_model")
                message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
                    best_top1_acc, best_top1_epoch)
                logger.info(message)

            # 3. save the persistable model
            if epoch_id % config.save_interval == 0:
                model_path = os.path.join(config.model_save_dir,
                                          config.ARCHITECTURE["name"])
                save_model(net, optimizer, model_path, epoch_id)
    except Exception as e:
        logger.error(e)
    finally:
        if vdl_writer:
            vdl_writer.close()
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
from ppcls.utils import config
from ppcls.engine.trainer import Trainer
if __name__ == '__main__':
args = parse_args()
main(args)
if __name__ == "__main__":
    # Script entry point: parse the CLI, load the config, run training.
    cli_args = config.parse_args()
    train_config = config.get_config(
        cli_args.config, overrides=cli_args.override, show=True)
    Trainer(train_config, mode="train").train()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册