Commit e7e4f68b authored by Tingquan Gao

Revert "refactor: build_train_func & build_eval_func"

This reverts commit 6bed0f57.
Parent 6245b64c
@@ -88,7 +88,7 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
     random.seed(worker_seed)


-def build(config, mode, use_dali=False, seed=None):
+def build(config, mode, device, use_dali=False, seed=None):
     assert mode in [
         'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
     ], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
@@ -167,7 +167,7 @@ def build(config, mode, use_dali=False, seed=None):
     if batch_sampler is None:
         data_loader = DataLoader(
             dataset=dataset,
-            places=paddle.device.get_device(),
+            places=device,
             num_workers=num_workers,
             return_list=True,
             use_shared_memory=use_shared_memory,
@@ -179,7 +179,7 @@ def build(config, mode, use_dali=False, seed=None):
     else:
         data_loader = DataLoader(
             dataset=dataset,
-            places=paddle.device.get_device(),
+            places=device,
             num_workers=num_workers,
             return_list=True,
             use_shared_memory=use_shared_memory,
@@ -218,7 +218,11 @@ def build_dataloader(engine):
     }
     if engine.mode == 'train':
         train_dataloader = build(
-            engine.config["DataLoader"], "Train", use_dali, seed=None)
+            engine.config["DataLoader"],
+            "Train",
+            engine.device,
+            use_dali,
+            seed=None)
         iter_per_epoch = len(train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(train_dataloader)
         if engine.config["Global"].get("iter_per_epoch", None):
@@ -231,23 +235,33 @@ def build_dataloader(engine):
     if engine.config["DataLoader"].get('UnLabelTrain', None) is not None:
         dataloader_dict["UnLabelTrain"] = build(
-            engine.config["DataLoader"], "UnLabelTrain", use_dali, seed=None)
+            engine.config["DataLoader"],
+            "UnLabelTrain",
+            engine.device,
+            use_dali,
+            seed=None)

     if engine.mode == "eval" or (engine.mode == "train" and
                                  engine.config["Global"]["eval_during_train"]):
-        if engine.config["Global"][
-                "eval_mode"] in ["classification", "adaface"]:
+        if engine.eval_mode in ["classification", "adaface"]:
             dataloader_dict["Eval"] = build(
-                engine.config["DataLoader"], "Eval", use_dali, seed=None)
-        elif engine.config["Global"]["eval_mode"] == "retrieval":
+                engine.config["DataLoader"],
+                "Eval",
+                engine.device,
+                use_dali,
+                seed=None)
+        elif engine.eval_mode == "retrieval":
             if len(engine.config["DataLoader"]["Eval"].keys()) == 1:
                 key = list(engine.config["DataLoader"]["Eval"].keys())[0]
-                dataloader_dict["GalleryQuery"] = build(
-                    engine.config["DataLoader"]["Eval"], key, use_dali)
+                dataloader_dict["GalleryQuery"] = build_dataloader(
+                    engine.config["DataLoader"]["Eval"], key, engine.device,
+                    use_dali)
             else:
-                dataloader_dict["Gallery"] = build(
-                    engine.config["DataLoader"]["Eval"], "Gallery", use_dali)
-                dataloader_dict["Query"] = build(
-                    engine.config["DataLoader"]["Eval"], "Query", use_dali)
+                dataloader_dict["Gallery"] = build_dataloader(
+                    engine.config["DataLoader"]["Eval"], "Gallery",
+                    engine.device, use_dali)
+                dataloader_dict["Query"] = build_dataloader(
+                    engine.config["DataLoader"]["Eval"], "Query",
+                    engine.device, use_dali)
     return dataloader_dict
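The hunks above thread a device argument from the engine down into build(), which hands it to paddle.io.DataLoader through places instead of calling paddle.device.get_device() inside the loader factory. The following standalone sketch is not part of the commit (the toy RandomDataset is hypothetical); it only illustrates the relationship being relied on: paddle.set_device() returns a place object that DataLoader accepts as places, and the Engine hunk further down is what stores that return value as engine.device.

import paddle
from paddle.io import DataLoader, Dataset


class RandomDataset(Dataset):
    """Toy dataset used only to exercise the loader."""

    def __init__(self, num_samples=8):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = paddle.randn([3, 32, 32])
        label = paddle.to_tensor([idx % 2], dtype="int64")
        return image, label

    def __len__(self):
        return self.num_samples


# paddle.set_device() selects the global device and returns a place object;
# the reverted code keeps that return value and forwards it as
# DataLoader(places=...), exactly the argument build() now requires.
device = paddle.set_device("cpu")  # or "gpu", mirroring Global.device

loader = DataLoader(
    dataset=RandomDataset(),
    places=device,  # the argument the revert threads through build()
    batch_size=4,
    num_workers=0,
    return_list=True)

for images, labels in loader:
    print(images.shape, labels.shape)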
@@ -39,8 +39,7 @@ from ppcls.utils import save_load
 from ppcls.data.utils.get_image_list import get_image_list
 from ppcls.data.postprocess import build_postprocess
 from ppcls.data import create_operators
-from .train import build_train_epoch_func
-from .evaluation import build_eval_func
+from ppcls.engine import train as train_method
 from ppcls.engine.train.utils import type_name
 from ppcls.engine import evaluation
 from ppcls.arch.gears.identity_head import IdentityHead
@@ -62,11 +61,22 @@ class Engine(object):
         self.vdl_writer = self._init_vdl()

         # init train_func and eval_func
-        self.train_epoch_func = build_train_epoch_func(self.config)
-        self.eval_epoch_func = build_eval_func(self.config)
+        self.train_mode = self.config["Global"].get("train_mode", None)
+        if self.train_mode is None:
+            self.train_epoch_func = train_method.train_epoch
+        else:
+            self.train_epoch_func = getattr(train_method,
+                                            "train_epoch_" + self.train_mode)
+        self.eval_mode = self.config["Global"].get("eval_mode",
+                                                   "classification")
+        assert self.eval_mode in [
+            "classification", "retrieval", "adaface"
+        ], logger.error("Invalid eval mode: {}".format(self.eval_mode))
+        self.eval_func = getattr(evaluation, self.eval_mode + "_eval")

         # set device
-        self._init_device()
+        self.device = self._init_device()

         # gradient accumulation
         self.update_freq = self.config["Global"].get("update_freq", 1)
@@ -385,7 +395,7 @@ class Engine(object):
         assert device in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"]
         logger.info('train with paddle {} and device {}'.format(
             paddle.__version__, device))
-        paddle.set_device(device)
+        return paddle.set_device(device)

     def _init_pretrained(self):
         if self.config["Global"]["pretrained_model"] is not None:
...
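For readers skimming the Engine change: the revert drops the build_train_epoch_func / build_eval_func factories and goes back to resolving the epoch and evaluation functions by naming convention with getattr. Below is a minimal self-contained sketch of that dispatch pattern; it uses stand-in namespaces rather than the real ppcls.engine.train and ppcls.engine.evaluation modules, and resolve_funcs is a hypothetical helper, not PaddleClas API.

import types

# Stand-ins for the ppcls.engine.train / ppcls.engine.evaluation modules.
train_method = types.SimpleNamespace(
    train_epoch=lambda engine, epoch_id, step: "regular epoch",
    train_epoch_progressive=lambda engine, epoch_id, step: "progressive epoch")
evaluation = types.SimpleNamespace(
    classification_eval=lambda engine, epoch_id=0: "classification metrics",
    retrieval_eval=lambda engine, epoch_id=0: "retrieval metrics")


def resolve_funcs(global_config):
    """Mirror the reverted Engine.__init__ logic on a plain config dict."""
    train_mode = global_config.get("train_mode", None)
    if train_mode is None:
        train_epoch_func = train_method.train_epoch
    else:
        # e.g. train_mode "progressive" -> train_epoch_progressive
        train_epoch_func = getattr(train_method, "train_epoch_" + train_mode)

    eval_mode = global_config.get("eval_mode", "classification")
    assert eval_mode in ["classification", "retrieval", "adaface"]
    eval_func = getattr(evaluation, eval_mode + "_eval")
    return train_epoch_func, eval_func


train_fn, eval_fn = resolve_funcs({"train_mode": "progressive",
                                   "eval_mode": "retrieval"})
print(train_fn(None, 0, 10), "/", eval_fn(None))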
@@ -12,15 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .classification import classification_eval
-from .retrieval import retrieval_eval
-from .adaface import adaface_eval
-
-
-def build_eval_func(config):
-    eval_mode = config["Global"].get("eval_mode", None)
-    if eval_mode is None:
-        config["Global"]["eval_mode"] = "classification"
-        return classification_eval
-    else:
-        return getattr(sys.modules[__name__], eval_mode + "_eval")
+from ppcls.engine.evaluation.classification import classification_eval
+from ppcls.engine.evaluation.retrieval import retrieval_eval
+from ppcls.engine.evaluation.adaface import adaface_eval
\ No newline at end of file
@@ -11,18 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .train_metabin import train_epoch_metabin
-from .regular_train_epoch import regular_train_epoch
-from .train_fixmatch import train_epoch_fixmatch
-from .train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
-from .train_progressive import train_epoch_progressive
-
-
-def build_train_epoch_func(config):
-    train_mode = config["Global"].get("train_mode", None)
-    if train_mode is None:
-        config["Global"]["train_mode"] = "regular_train"
-        return regular_train_epoch
-    else:
-        return getattr(sys.modules[__name__], "train_epoch_" + train_mode)
+from ppcls.engine.train.train import train_epoch
+from ppcls.engine.train.train_fixmatch import train_epoch_fixmatch
+from ppcls.engine.train.train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
+from ppcls.engine.train.train_progressive import train_epoch_progressive
+from ppcls.engine.train.train_metabin import train_epoch_metabin
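By contrast, the hunk above deletes the factory that the refactor had added to the train package's __init__. Reconstructed below as a standalone script (the two epoch functions are trivial stubs, not the real implementations): it resolves the function from Global.train_mode via sys.modules[__name__] and writes the default mode back into the config. The visible hunk does not show the import sys this relies on, so that import presumably sat elsewhere in the pre-revert file.

import sys


def regular_train_epoch(engine, epoch_id, print_batch_step):
    # Stub standing in for the pre-revert regular_train_epoch implementation.
    return "regular epoch {}".format(epoch_id)


def train_epoch_progressive(engine, epoch_id, print_batch_step):
    # Stub standing in for the progressive-training epoch function.
    return "progressive epoch {}".format(epoch_id)


def build_train_epoch_func(config):
    """Reconstruction of the deleted factory: pick the epoch function from
    Global.train_mode and write the default back into the config."""
    train_mode = config["Global"].get("train_mode", None)
    if train_mode is None:
        config["Global"]["train_mode"] = "regular_train"
        return regular_train_epoch
    else:
        # Look the function up on this module by naming convention.
        return getattr(sys.modules[__name__], "train_epoch_" + train_mode)


config = {"Global": {"train_mode": "progressive"}}
print(build_train_epoch_func(config)(None, 1, 10))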
@@ -19,7 +19,7 @@ from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_
 from ppcls.utils import profiler


-def regular_train_epoch(engine, epoch_id, print_batch_step):
+def train_epoch(engine, epoch_id, print_batch_step):
     tic = time.time()
     if not hasattr(engine, "train_dataloader_iter"):
...
@@ -16,7 +16,8 @@ from __future__ import absolute_import, division, print_function
 from ppcls.data import build_dataloader
 from ppcls.engine.train.utils import type_name
 from ppcls.utils import logger
-from .regular_train_epoch import regular_train_epoch
+from .train import train_epoch


 def train_epoch_progressive(engine, epoch_id, print_batch_step):
@@ -68,4 +69,4 @@ def train_epoch_progressive(engine, epoch_id, print_batch_step):
             f")")

     # 3. Train one epoch as usual at current stage
-    regular_train_epoch(engine, epoch_id, print_batch_step)
+    train_epoch(engine, epoch_id, print_batch_step)