提交 6bed0f57 编写于 作者: G gaotingquan 提交者: Wei Shengyu

refactor: build_train_func & build_eval_func

1. rm engine.device and use paddle.device.get_device() instead;
2. mv some code to build_train_func or build_eval_func to simplify engine.
上级 75a20ba5
...@@ -88,7 +88,7 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int): ...@@ -88,7 +88,7 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
random.seed(worker_seed) random.seed(worker_seed)
def build(config, mode, device, use_dali=False, seed=None): def build(config, mode, use_dali=False, seed=None):
assert mode in [ assert mode in [
'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain' 'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain" ], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
...@@ -167,7 +167,7 @@ def build(config, mode, device, use_dali=False, seed=None): ...@@ -167,7 +167,7 @@ def build(config, mode, device, use_dali=False, seed=None):
if batch_sampler is None: if batch_sampler is None:
data_loader = DataLoader( data_loader = DataLoader(
dataset=dataset, dataset=dataset,
places=device, places=paddle.device.get_device(),
num_workers=num_workers, num_workers=num_workers,
return_list=True, return_list=True,
use_shared_memory=use_shared_memory, use_shared_memory=use_shared_memory,
...@@ -179,7 +179,7 @@ def build(config, mode, device, use_dali=False, seed=None): ...@@ -179,7 +179,7 @@ def build(config, mode, device, use_dali=False, seed=None):
else: else:
data_loader = DataLoader( data_loader = DataLoader(
dataset=dataset, dataset=dataset,
places=device, places=paddle.device.get_device(),
num_workers=num_workers, num_workers=num_workers,
return_list=True, return_list=True,
use_shared_memory=use_shared_memory, use_shared_memory=use_shared_memory,
...@@ -218,11 +218,7 @@ def build_dataloader(engine): ...@@ -218,11 +218,7 @@ def build_dataloader(engine):
} }
if engine.mode == 'train': if engine.mode == 'train':
train_dataloader = build( train_dataloader = build(
engine.config["DataLoader"], engine.config["DataLoader"], "Train", use_dali, seed=None)
"Train",
engine.device,
use_dali,
seed=None)
iter_per_epoch = len(train_dataloader) - 1 if platform.system( iter_per_epoch = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader) ) == "Windows" else len(train_dataloader)
if engine.config["Global"].get("iter_per_epoch", None): if engine.config["Global"].get("iter_per_epoch", None):
...@@ -235,33 +231,23 @@ def build_dataloader(engine): ...@@ -235,33 +231,23 @@ def build_dataloader(engine):
if engine.config["DataLoader"].get('UnLabelTrain', None) is not None: if engine.config["DataLoader"].get('UnLabelTrain', None) is not None:
dataloader_dict["UnLabelTrain"] = build( dataloader_dict["UnLabelTrain"] = build(
engine.config["DataLoader"], engine.config["DataLoader"], "UnLabelTrain", use_dali, seed=None)
"UnLabelTrain",
engine.device,
use_dali,
seed=None)
if engine.mode == "eval" or (engine.mode == "train" and if engine.mode == "eval" or (engine.mode == "train" and
engine.config["Global"]["eval_during_train"]): engine.config["Global"]["eval_during_train"]):
if engine.eval_mode in ["classification", "adaface"]: if engine.config["Global"][
"eval_mode"] in ["classification", "adaface"]:
dataloader_dict["Eval"] = build( dataloader_dict["Eval"] = build(
engine.config["DataLoader"], engine.config["DataLoader"], "Eval", use_dali, seed=None)
"Eval", elif engine.config["Global"]["eval_mode"] == "retrieval":
engine.device,
use_dali,
seed=None)
elif engine.eval_mode == "retrieval":
if len(engine.config["DataLoader"]["Eval"].keys()) == 1: if len(engine.config["DataLoader"]["Eval"].keys()) == 1:
key = list(engine.config["DataLoader"]["Eval"].keys())[0] key = list(engine.config["DataLoader"]["Eval"].keys())[0]
dataloader_dict["GalleryQuery"] = build_dataloader( dataloader_dict["GalleryQuery"] = build(
engine.config["DataLoader"]["Eval"], key, engine.device, engine.config["DataLoader"]["Eval"], key, use_dali)
use_dali)
else: else:
dataloader_dict["Gallery"] = build_dataloader( dataloader_dict["Gallery"] = build(
engine.config["DataLoader"]["Eval"], "Gallery", engine.config["DataLoader"]["Eval"], "Gallery", use_dali)
engine.device, use_dali) dataloader_dict["Query"] = build(
dataloader_dict["Query"] = build_dataloader( engine.config["DataLoader"]["Eval"], "Query", use_dali)
engine.config["DataLoader"]["Eval"], "Query",
engine.device, use_dali)
return dataloader_dict return dataloader_dict
...@@ -39,7 +39,8 @@ from ppcls.utils import save_load ...@@ -39,7 +39,8 @@ from ppcls.utils import save_load
from ppcls.data.utils.get_image_list import get_image_list from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators from ppcls.data import create_operators
from ppcls.engine import train as train_method from .train import build_train_epoch_func
from .evaluation import build_eval_func
from ppcls.engine.train.utils import type_name from ppcls.engine.train.utils import type_name
from ppcls.engine import evaluation from ppcls.engine import evaluation
from ppcls.arch.gears.identity_head import IdentityHead from ppcls.arch.gears.identity_head import IdentityHead
...@@ -61,22 +62,11 @@ class Engine(object): ...@@ -61,22 +62,11 @@ class Engine(object):
self.vdl_writer = self._init_vdl() self.vdl_writer = self._init_vdl()
# init train_func and eval_func # init train_func and eval_func
self.train_mode = self.config["Global"].get("train_mode", None) self.train_epoch_func = build_train_epoch_func(self.config)
if self.train_mode is None: self.eval_epoch_func = build_eval_func(self.config)
self.train_epoch_func = train_method.train_epoch
else:
self.train_epoch_func = getattr(train_method,
"train_epoch_" + self.train_mode)
self.eval_mode = self.config["Global"].get("eval_mode",
"classification")
assert self.eval_mode in [
"classification", "retrieval", "adaface"
], logger.error("Invalid eval mode: {}".format(self.eval_mode))
self.eval_func = getattr(evaluation, self.eval_mode + "_eval")
# set device # set device
self.device = self._init_device() self._init_device()
# gradient accumulation # gradient accumulation
self.update_freq = self.config["Global"].get("update_freq", 1) self.update_freq = self.config["Global"].get("update_freq", 1)
...@@ -395,7 +385,7 @@ class Engine(object): ...@@ -395,7 +385,7 @@ class Engine(object):
assert device in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"] assert device in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"]
logger.info('train with paddle {} and device {}'.format( logger.info('train with paddle {} and device {}'.format(
paddle.__version__, device)) paddle.__version__, device))
return paddle.set_device(device) paddle.set_device(device)
def _init_pretrained(self): def _init_pretrained(self):
if self.config["Global"]["pretrained_model"] is not None: if self.config["Global"]["pretrained_model"] is not None:
......
...@@ -12,6 +12,15 @@ ...@@ -12,6 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ppcls.engine.evaluation.classification import classification_eval from .classification import classification_eval
from ppcls.engine.evaluation.retrieval import retrieval_eval from .retrieval import retrieval_eval
from ppcls.engine.evaluation.adaface import adaface_eval from .adaface import adaface_eval
\ No newline at end of file
def build_eval_func(config):
    """Select the evaluation-epoch function from ``Global.eval_mode``.

    When ``eval_mode`` is unset, defaults to ``"classification"`` and writes
    the default back into ``config`` so downstream code (e.g. dataloader
    building) can branch on it.

    Args:
        config (dict): full config; reads/updates ``config["Global"]["eval_mode"]``.

    Returns:
        callable: one of ``classification_eval``, ``retrieval_eval``,
        ``adaface_eval`` (any ``<mode>_eval`` defined in this module).

    Raises:
        ValueError: if ``eval_mode`` does not match any ``<mode>_eval`` here.
    """
    eval_mode = config["Global"].get("eval_mode", None)
    if eval_mode is None:
        # Persist the default so later readers of the config see it.
        config["Global"]["eval_mode"] = "classification"
        return classification_eval
    # FIX: original used sys.modules[__name__] without importing sys
    # (NameError at runtime); module-level globals() is equivalent here.
    eval_func = globals().get(eval_mode + "_eval", None)
    if eval_func is None:
        raise ValueError("Invalid eval mode: {}".format(eval_mode))
    return eval_func
...@@ -11,8 +11,18 @@ ...@@ -11,8 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ppcls.engine.train.train import train_epoch
from ppcls.engine.train.train_fixmatch import train_epoch_fixmatch from .train_metabin import train_epoch_metabin
from ppcls.engine.train.train_fixmatch_ccssl import train_epoch_fixmatch_ccssl from .regular_train_epoch import regular_train_epoch
from ppcls.engine.train.train_progressive import train_epoch_progressive from .train_fixmatch import train_epoch_fixmatch
from ppcls.engine.train.train_metabin import train_epoch_metabin from .train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
from .train_progressive import train_epoch_progressive
def build_train_epoch_func(config):
    """Select the train-epoch function from ``Global.train_mode``.

    When ``train_mode`` is unset, defaults to ``"regular_train"`` and writes
    the default back into ``config`` so downstream code can branch on it.

    Args:
        config (dict): full config; reads/updates ``config["Global"]["train_mode"]``.

    Returns:
        callable: ``regular_train_epoch`` or a ``train_epoch_<mode>`` function
        defined in this module (fixmatch, fixmatch_ccssl, progressive, metabin).

    Raises:
        ValueError: if ``train_mode`` does not match any known epoch function.
    """
    train_mode = config["Global"].get("train_mode", None)
    # FIX: also accept an explicit "regular_train" — the original only handled
    # None, so a second call with the same (defaulted) config looked up the
    # nonexistent name "train_epoch_regular_train" and crashed.
    if train_mode is None or train_mode == "regular_train":
        # Persist the default so later readers of the config see it.
        config["Global"]["train_mode"] = "regular_train"
        return regular_train_epoch
    # FIX: original used sys.modules[__name__] without importing sys
    # (NameError at runtime); module-level globals() is equivalent here.
    train_func = globals().get("train_epoch_" + train_mode, None)
    if train_func is None:
        raise ValueError("Invalid train mode: {}".format(train_mode))
    return train_func
...@@ -19,7 +19,7 @@ from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_ ...@@ -19,7 +19,7 @@ from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_
from ppcls.utils import profiler from ppcls.utils import profiler
def train_epoch(engine, epoch_id, print_batch_step): def regular_train_epoch(engine, epoch_id, print_batch_step):
tic = time.time() tic = time.time()
if not hasattr(engine, "train_dataloader_iter"): if not hasattr(engine, "train_dataloader_iter"):
......
...@@ -16,8 +16,7 @@ from __future__ import absolute_import, division, print_function ...@@ -16,8 +16,7 @@ from __future__ import absolute_import, division, print_function
from ppcls.data import build_dataloader from ppcls.data import build_dataloader
from ppcls.engine.train.utils import type_name from ppcls.engine.train.utils import type_name
from ppcls.utils import logger from ppcls.utils import logger
from .regular_train_epoch import regular_train_epoch
from .train import train_epoch
def train_epoch_progressive(engine, epoch_id, print_batch_step): def train_epoch_progressive(engine, epoch_id, print_batch_step):
...@@ -69,4 +68,4 @@ def train_epoch_progressive(engine, epoch_id, print_batch_step): ...@@ -69,4 +68,4 @@ def train_epoch_progressive(engine, epoch_id, print_batch_step):
f")") f")")
# 3. Train one epoch as usual at current stage # 3. Train one epoch as usual at current stage
train_epoch(engine, epoch_id, print_batch_step) regular_train_epoch(engine, epoch_id, print_batch_step)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册