提交 a9b84325 编写于 作者: T tianyi1997 提交者: HydrogenSulfate

fix: wrong base class & simplify train func

上级 f02a4630
......@@ -22,7 +22,7 @@ import paddle.nn.functional as F
from ..utils import get_param_attr_dict
class MetaBN1D(nn.BatchNorm2D):
class MetaBN1D(nn.BatchNorm1D):
def forward(self, inputs, opt={}):
mode = opt.get("bn_mode", "general") if self.training else "eval"
if mode == "general": # update, but not apply running_mean/var
......
......@@ -141,15 +141,15 @@ def setup_opt(engine, stage):
opt["bn_mode"] = "hold"
opt["enable_inside_update"] = True
opt["lr_gate"] = norm_lr * cyclic_lr
for name, layer in engine.model.backbone.named_sublayers():
if "bn" == name.split('.')[-1]:
for layer in engine.model.backbone.sublayers():
if type_name(layer) == "MetaBIN":
layer.setup_opt(opt)
engine.model.neck.setup_opt(opt)
def reset_opt(model):
for name, layer in model.backbone.named_sublayers():
if "bn" == name.split('.')[-1]:
for layer in model.backbone.sublayers():
if type_name(layer) == "MetaBIN":
layer.reset_opt()
model.neck.reset_opt()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册