Commit 75a20ba5 authored by gaotingquan, committed by Wei Shengyu

refactor: add ClassModel to unify model forward interface

Parent: 376d83d4
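The gist of the change, as a minimal runnable sketch (illustrative class names, not the PaddleClas ones): before this commit each architecture exposed its own forward signature and the trainer had to branch on is_rec; after it, every model consumes the whole batch and unpacks it itself, so there is a single call site.

import paddle
import paddle.nn as nn

class OldStyleModel(nn.Layer):
    # before: the caller must know whether to pass a label
    def forward(self, x, label=None):
        return x.mean(axis=-1)

class NewStyleModel(nn.Layer):
    # after: the model unpacks the batch, so any model is called as model(batch)
    def forward(self, batch):
        x, label = batch[0], batch[1]
        return x.mean(axis=-1)

batch = [paddle.rand([4, 3, 224, 224]), paddle.zeros([4, 1], dtype="int64")]
out = NewStyleModel()(batch)  # uniform call site, no is_rec branching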
@@ -12,14 +12,14 @@
 #See the License for the specific language governing permissions and
 #limitations under the License.
+import sys
 import copy
-import importlib
 import paddle.nn as nn
 from paddle.jit import to_static
 from paddle.static import InputSpec
-from . import backbone, gears
-from .backbone import *
+from . import backbone as backbone_zoo
 from .gears import build_gear
 from .utils import *
+from .backbone.base.theseus_layer import TheseusLayer
@@ -35,20 +35,28 @@ def build_model(config, mode="train"):
     arch_config = copy.deepcopy(config["Arch"])
     model_type = arch_config.pop("name")
     use_sync_bn = arch_config.pop("use_sync_bn", False)
-    mod = importlib.import_module(__name__)
-    arch = getattr(mod, model_type)(**arch_config)
+    if hasattr(backbone_zoo, model_type):
+        model = ClassModel(model_type, **arch_config)
+    else:
+        model = getattr(sys.modules[__name__], model_type)("ClassModel",
+                                                           **arch_config)
     if use_sync_bn:
         if config["Global"]["device"] == "gpu":
-            arch = nn.SyncBatchNorm.convert_sync_batchnorm(arch)
+            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
         else:
             msg = "SyncBatchNorm can only be used on GPU device. The related setting has been ignored."
             logger.warning(msg)
-    if isinstance(arch, TheseusLayer):
-        prune_model(config, arch)
-        quantize_model(config, arch, mode)
-    return arch
+    if isinstance(model, TheseusLayer):
+        prune_model(config, model)
+        quantize_model(config, model, mode)
+
+    # set @to_static for benchmark, skip this by default.
+    model = apply_to_static(config, model)
+
+    return model


 def apply_to_static(config, model):
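build_model now finishes by routing the model through apply_to_static, whose body is collapsed in this view. As a hedged sketch of what such a helper does here (the config keys, default shape, and function name below are assumptions, not the committed code): it wraps the model with paddle.jit.to_static under an InputSpec when the benchmark switch is enabled.

from paddle.jit import to_static
from paddle.static import InputSpec

def apply_to_static_sketch(config, model):
    # skip by default; only convert when the (assumed) switch is on
    if not config["Global"].get("to_static", False):
        return model
    # assumed default input shape; the real code would read it from the config
    image_shape = config["Global"].get("image_shape", [3, 224, 224])
    specs = [InputSpec(shape=[None] + image_shape, dtype="float32")]
    return to_static(model, input_spec=specs)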
@@ -65,12 +73,29 @@ def apply_to_static(config, model):
     return model


+# TODO(gaotingquan): export model
+class ClassModel(TheseusLayer):
+    def __init__(self, model_type, **config):
+        super().__init__()
+        if model_type == "ClassModel":
+            backbone_config = config["Backbone"]
+            backbone_name = backbone_config.pop("name")
+        else:
+            backbone_name = model_type
+            backbone_config = config
+        self.backbone = getattr(backbone_zoo, backbone_name)(**backbone_config)
+
+    def forward(self, batch):
+        x, label = batch[0], batch[1]
+        return self.backbone(x)
+
+
 class RecModel(TheseusLayer):
     def __init__(self, **config):
         super().__init__()
         backbone_config = config["Backbone"]
         backbone_name = backbone_config.pop("name")
-        self.backbone = eval(backbone_name)(**backbone_config)
+        self.backbone = getattr(backbone_zoo, backbone_name)(**backbone_config)
         self.head_feature_from = config.get('head_feature_from', 'neck')

         if "BackboneStopLayer" in config:
@@ -87,8 +112,8 @@ class RecModel(TheseusLayer):
         else:
             self.head = None

-    def forward(self, x, label=None):
+    def forward(self, batch):
+        x, label = batch[0], batch[1]
         out = dict()
         x = self.backbone(x)
         out["backbone"] = x
@@ -140,7 +165,8 @@ class DistillationModel(nn.Layer):
                 load_dygraph_pretrain(
                     self.model_name_list[idx], path=pretrained)

-    def forward(self, x, label=None):
+    def forward(self, batch):
+        x, label = batch[0], batch[1]
         result_dict = dict()
         for idx, model_name in enumerate(self.model_name_list):
             if label is None:
@@ -158,7 +184,8 @@ class AttentionModel(DistillationModel):
                  **kargs):
         super().__init__(models, pretrained_list, freeze_params_list, **kargs)

-    def forward(self, x, label=None):
+    def forward(self, batch):
+        x, label = batch[0], batch[1]
         result_dict = dict()
         out = x
         for idx, model_name in enumerate(self.model_name_list):
@@ -168,4 +195,4 @@ class AttentionModel(DistillationModel):
         else:
             out = self.model_list[idx](out, label)
         result_dict.update(out)
-        return result_dict
\ No newline at end of file
+        return result_dict
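Taken together, Arch.name now decides the wrapper at build time. Two illustrative configs, expressed as Python dicts (keys other than name and Backbone are assumptions):

# Case 1: the name is a backbone in the zoo, so it gets wrapped in ClassModel.
cfg_cls = {"Arch": {"name": "ResNet50", "class_num": 1000}}
# hasattr(backbone_zoo, "ResNet50") is True, so build_model returns
# ClassModel("ResNet50", class_num=1000).

# Case 2: the name is an architecture defined in ppcls.arch itself.
cfg_rec = {"Arch": {"name": "RecModel", "Backbone": {"name": "ResNet50"}}}
# "RecModel" is not in the backbone zoo, so build_model falls back to
# getattr(sys.modules[__name__], "RecModel")("ClassModel", **arch_config),
# as shown in the build_model hunk above.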
@@ -28,7 +28,6 @@ from ppcls.utils.logger import init_logger
 from ppcls.utils.config import print_config
 from ppcls.data import build_dataloader
 from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
-from ppcls.arch import apply_to_static
 from ppcls.loss import build_loss
 from ppcls.metric import build_metrics
 from ppcls.optimizer import build_optimizer
@@ -57,18 +56,10 @@ class Engine(object):
         # init logger
         init_logger(self.config, mode=mode)
-        print_config(config)

         # for visualdl
         self.vdl_writer = self._init_vdl()

-        # is_rec
-        if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
-                                                                    False):
-            self.is_rec = True
-        else:
-            self.is_rec = False
-
         # init train_func and eval_func
         self.train_mode = self.config["Global"].get("train_mode", None)
         if self.train_mode is None:
@@ -108,8 +99,6 @@ class Engine(object):
         # build model
         self.model = build_model(self.config, self.mode)
-        # set @to_static for benchmark, skip this by default.
-        apply_to_static(self.config, self.model)

         # load_pretrain
         self._init_pretrained()
@@ -125,6 +114,8 @@ class Engine(object):
         # for distributed
         self._init_dist()

+        print_config(config)
+
     def train(self):
         assert self.mode == "train"
         print_batch_step = self.config['Global']['print_batch_step']
@@ -55,10 +55,10 @@ def train_epoch(engine, epoch_id, print_batch_step):
                         "flatten_contiguous_range", "greater_than"
                     },
                     level=amp_level):
-                out = forward(engine, batch)
+                out = engine.model(batch)
                 loss_dict = engine.train_loss_func(out, batch[1])
         else:
-            out = forward(engine, batch)
+            out = engine.model(batch)
             loss_dict = engine.train_loss_func(out, batch[1])

         # loss
@@ -104,10 +104,3 @@ def train_epoch(engine, epoch_id, print_batch_step):
         if getattr(engine.lr_sch[i], "by_epoch", False) and \
                 type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
             engine.lr_sch[i].step()
-
-
-def forward(engine, batch):
-    if not engine.is_rec:
-        return engine.model(batch[0])
-    else:
-        return engine.model(batch[0], batch[1])
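With every architecture accepting the full batch, the per-step forward helper above disappears and the train loop calls engine.model(batch) directly. A minimal sketch of the resulting step (loss_fn and optimizer are generic placeholders, not the PaddleClas API):

import paddle

def train_step_sketch(model, loss_fn, optimizer, batch):
    out = model(batch)             # unified interface: the model unpacks the batch
    loss = loss_fn(out, batch[1])
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()
    return loss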