提交 e7e4f68b 编写于 作者: T Tingquan Gao

Revert "refactor: build_train_func & build_eval_func"

This reverts commit 6bed0f57.
上级 6245b64c
......@@ -88,7 +88,7 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
random.seed(worker_seed)
def build(config, mode, use_dali=False, seed=None):
def build(config, mode, device, use_dali=False, seed=None):
assert mode in [
'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
......@@ -167,7 +167,7 @@ def build(config, mode, use_dali=False, seed=None):
if batch_sampler is None:
data_loader = DataLoader(
dataset=dataset,
places=paddle.device.get_device(),
places=device,
num_workers=num_workers,
return_list=True,
use_shared_memory=use_shared_memory,
......@@ -179,7 +179,7 @@ def build(config, mode, use_dali=False, seed=None):
else:
data_loader = DataLoader(
dataset=dataset,
places=paddle.device.get_device(),
places=device,
num_workers=num_workers,
return_list=True,
use_shared_memory=use_shared_memory,
......@@ -218,7 +218,11 @@ def build_dataloader(engine):
}
if engine.mode == 'train':
train_dataloader = build(
engine.config["DataLoader"], "Train", use_dali, seed=None)
engine.config["DataLoader"],
"Train",
engine.device,
use_dali,
seed=None)
iter_per_epoch = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader)
if engine.config["Global"].get("iter_per_epoch", None):
......@@ -231,23 +235,33 @@ def build_dataloader(engine):
if engine.config["DataLoader"].get('UnLabelTrain', None) is not None:
dataloader_dict["UnLabelTrain"] = build(
engine.config["DataLoader"], "UnLabelTrain", use_dali, seed=None)
engine.config["DataLoader"],
"UnLabelTrain",
engine.device,
use_dali,
seed=None)
if engine.mode == "eval" or (engine.mode == "train" and
engine.config["Global"]["eval_during_train"]):
if engine.config["Global"][
"eval_mode"] in ["classification", "adaface"]:
if engine.eval_mode in ["classification", "adaface"]:
dataloader_dict["Eval"] = build(
engine.config["DataLoader"], "Eval", use_dali, seed=None)
elif engine.config["Global"]["eval_mode"] == "retrieval":
engine.config["DataLoader"],
"Eval",
engine.device,
use_dali,
seed=None)
elif engine.eval_mode == "retrieval":
if len(engine.config["DataLoader"]["Eval"].keys()) == 1:
key = list(engine.config["DataLoader"]["Eval"].keys())[0]
dataloader_dict["GalleryQuery"] = build(
engine.config["DataLoader"]["Eval"], key, use_dali)
dataloader_dict["GalleryQuery"] = build_dataloader(
engine.config["DataLoader"]["Eval"], key, engine.device,
use_dali)
else:
dataloader_dict["Gallery"] = build(
engine.config["DataLoader"]["Eval"], "Gallery", use_dali)
dataloader_dict["Query"] = build(
engine.config["DataLoader"]["Eval"], "Query", use_dali)
dataloader_dict["Gallery"] = build_dataloader(
engine.config["DataLoader"]["Eval"], "Gallery",
engine.device, use_dali)
dataloader_dict["Query"] = build_dataloader(
engine.config["DataLoader"]["Eval"], "Query",
engine.device, use_dali)
return dataloader_dict
......@@ -39,8 +39,7 @@ from ppcls.utils import save_load
from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators
from .train import build_train_epoch_func
from .evaluation import build_eval_func
from ppcls.engine import train as train_method
from ppcls.engine.train.utils import type_name
from ppcls.engine import evaluation
from ppcls.arch.gears.identity_head import IdentityHead
......@@ -62,11 +61,22 @@ class Engine(object):
self.vdl_writer = self._init_vdl()
# init train_func and eval_func
self.train_epoch_func = build_train_epoch_func(self.config)
self.eval_epoch_func = build_eval_func(self.config)
self.train_mode = self.config["Global"].get("train_mode", None)
if self.train_mode is None:
self.train_epoch_func = train_method.train_epoch
else:
self.train_epoch_func = getattr(train_method,
"train_epoch_" + self.train_mode)
self.eval_mode = self.config["Global"].get("eval_mode",
"classification")
assert self.eval_mode in [
"classification", "retrieval", "adaface"
], logger.error("Invalid eval mode: {}".format(self.eval_mode))
self.eval_func = getattr(evaluation, self.eval_mode + "_eval")
# set device
self._init_device()
self.device = self._init_device()
# gradient accumulation
self.update_freq = self.config["Global"].get("update_freq", 1)
......@@ -385,7 +395,7 @@ class Engine(object):
assert device in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"]
logger.info('train with paddle {} and device {}'.format(
paddle.__version__, device))
paddle.set_device(device)
return paddle.set_device(device)
def _init_pretrained(self):
if self.config["Global"]["pretrained_model"] is not None:
......
......@@ -12,15 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .classification import classification_eval
from .retrieval import retrieval_eval
from .adaface import adaface_eval
def build_eval_func(config):
eval_mode = config["Global"].get("eval_mode", None)
if eval_mode is None:
config["Global"]["eval_mode"] = "classification"
return classification_eval
else:
return getattr(sys.modules[__name__], eval_mode + "_eval")
from ppcls.engine.evaluation.classification import classification_eval
from ppcls.engine.evaluation.retrieval import retrieval_eval
from ppcls.engine.evaluation.adaface import adaface_eval
\ No newline at end of file
......@@ -11,18 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .train_metabin import train_epoch_metabin
from .regular_train_epoch import regular_train_epoch
from .train_fixmatch import train_epoch_fixmatch
from .train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
from .train_progressive import train_epoch_progressive
def build_train_epoch_func(config):
train_mode = config["Global"].get("train_mode", None)
if train_mode is None:
config["Global"]["train_mode"] = "regular_train"
return regular_train_epoch
else:
return getattr(sys.modules[__name__], "train_epoch_" + train_mode)
from ppcls.engine.train.train import train_epoch
from ppcls.engine.train.train_fixmatch import train_epoch_fixmatch
from ppcls.engine.train.train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
from ppcls.engine.train.train_progressive import train_epoch_progressive
from ppcls.engine.train.train_metabin import train_epoch_metabin
......@@ -19,7 +19,7 @@ from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_
from ppcls.utils import profiler
def regular_train_epoch(engine, epoch_id, print_batch_step):
def train_epoch(engine, epoch_id, print_batch_step):
tic = time.time()
if not hasattr(engine, "train_dataloader_iter"):
......
......@@ -16,7 +16,8 @@ from __future__ import absolute_import, division, print_function
from ppcls.data import build_dataloader
from ppcls.engine.train.utils import type_name
from ppcls.utils import logger
from .regular_train_epoch import regular_train_epoch
from .train import train_epoch
def train_epoch_progressive(engine, epoch_id, print_batch_step):
......@@ -68,4 +69,4 @@ def train_epoch_progressive(engine, epoch_id, print_batch_step):
f")")
# 3. Train one epoch as usual at current stage
regular_train_epoch(engine, epoch_id, print_batch_step)
train_epoch(engine, epoch_id, print_batch_step)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册