Commit 21e76d08 authored by tianyi1997, committed by HydrogenSulfate

Modify code based on reviews

Parent d79fb66e
@@ -42,50 +42,47 @@ def ResNet50_adaptive_max_pool2d(pretrained=False, use_ssld=False, **kwargs):
return model
class BINGate(nn.Layer):
def __init__(self, num_features):
super().__init__()
self.gate = self.create_parameter(
shape=[num_features],
default_initializer=nn.initializer.Constant(1.0))
self.add_parameter("gate", self.gate)
def forward(self, opt={}):
flag_update = 'lr_gate' in opt and \
opt.get('enable_inside_update', False)
if flag_update and self.gate.grad is not None: # update gate
lr = opt['lr_gate'] * self.gate.optimize_attr.get('learning_rate',
1.0)
gate = self.gate - lr * self.gate.grad
gate.clip_(min=0, max=1)
else:
gate = self.gate
return gate
def clip_gate(self):
self.gate.set_value(self.gate.clip(0, 1))
class MetaBN(nn.BatchNorm2D):
def forward(self, inputs, opt={}):
mode = opt.get("bn_mode", "general") if self.training else "eval"
if mode == "general": # update, but not apply running_mean/var
result = F.batch_norm(inputs, self._mean, self._variance,
self.weight, self.bias, self.training,
self._momentum, self._epsilon)
elif mode == "hold": # not update, not apply running_mean/var
result = F.batch_norm(
inputs,
paddle.mean(
inputs, axis=(0, 2, 3)),
paddle.var(inputs, axis=(0, 2, 3)),
self.weight,
self.bias,
self.training,
self._momentum,
self._epsilon)
elif mode == "eval": # fix and apply running_mean/var,
if self._mean is None:
def ResNet50_metabin(pretrained=False,
use_ssld=False,
bias_lr_factor=1.0,
gate_lr_factor=1.0,
**kwargs):
"""
ResNet50 which replaces all `bn` layers with MetaBIN
reference: https://arxiv.org/abs/2011.14670
"""
class BINGate(nn.Layer):
def __init__(self, num_features):
super().__init__()
self.gate = self.create_parameter(
shape=[num_features],
default_initializer=nn.initializer.Constant(1.0))
self.add_parameter("gate", self.gate)
def forward(self, opt={}):
flag_update = 'lr_gate' in opt and \
opt.get('enable_inside_update', False)
if flag_update and self.gate.grad is not None: # update gate
lr = opt['lr_gate'] * self.gate.optimize_attr.get(
'learning_rate', 1.0)
gate = self.gate - lr * self.gate.grad
gate.clip_(min=0, max=1)
else:
gate = self.gate
return gate
def clip_gate(self):
self.gate.set_value(self.gate.clip(0, 1))
class MetaBN(nn.BatchNorm2D):
def forward(self, inputs, opt={}):
mode = opt.get("bn_mode", "general") if self.training else "eval"
if mode == "general": # update, but not apply running_mean/var
result = F.batch_norm(inputs, self._mean, self._variance,
self.weight, self.bias, self.training,
self._momentum, self._epsilon)
elif mode == "hold": # not update, not apply running_mean/var
result = F.batch_norm(
inputs,
paddle.mean(
@@ -93,75 +90,75 @@ class MetaBN(nn.BatchNorm2D):
paddle.var(inputs, axis=(0, 2, 3)),
self.weight,
self.bias,
True,
self.training,
self._momentum,
self._epsilon)
else:
result = F.batch_norm(inputs, self._mean, self._variance,
self.weight, self.bias, False,
self._momentum, self._epsilon)
return result
class MetaBIN(nn.Layer):
"""
MetaBIN (Meta Batch-Instance Normalization)
reference: https://arxiv.org/abs/2011.14670
"""
def __init__(self, num_features):
super().__init__()
self.batch_norm = MetaBN(
num_features=num_features, use_global_stats=True)
self.instance_norm = nn.InstanceNorm2D(num_features=num_features)
self.gate = BINGate(num_features=num_features)
self.opt = defaultdict()
def forward(self, inputs):
out_bn = self.batch_norm(inputs, self.opt)
out_in = self.instance_norm(inputs)
gate = self.gate(self.opt)
gate = gate.unsqueeze([0, -1, -1])
out = out_bn * gate + out_in * (1 - gate)
return out
def reset_opt(self):
self.opt = defaultdict()
def setup_opt(self, opt):
elif mode == "eval": # fix and apply running_mean/var,
if self._mean is None:
result = F.batch_norm(
inputs,
paddle.mean(
inputs, axis=(0, 2, 3)),
paddle.var(inputs, axis=(0, 2, 3)),
self.weight,
self.bias,
True,
self._momentum,
self._epsilon)
else:
result = F.batch_norm(inputs, self._mean, self._variance,
self.weight, self.bias, False,
self._momentum, self._epsilon)
return result
class MetaBIN(nn.Layer):
"""
enable_inside_update: enable inside updating for `gate` in MetaBIN
lr_gate: learning rate of `gate` during meta-train phase
bn_mode: control the running stats & updating of BN
MetaBIN (Meta Batch-Instance Normalization)
reference: https://arxiv.org/abs/2011.14670
"""
self.check_opt(opt)
self.opt = copy.deepcopy(opt)
@classmethod
def check_opt(cls, opt):
assert isinstance(opt, dict), \
TypeError('Got the wrong type of `opt`. Please use `dict` type.')
if opt.get('enable_inside_update', False) and 'lr_gate' not in opt:
raise RuntimeError('Missing `lr_gate` in opt.')
assert isinstance(opt.get('lr_gate', 1.0), float), \
TypeError('Got the wrong type of `lr_gate`. Please use `float` type.')
assert isinstance(opt.get('enable_inside_update', True), bool), \
TypeError('Got the wrong type of `enable_inside_update`. Please use `bool` type.')
assert opt.get('bn_mode', "general") in ["general", "hold", "eval"], \
TypeError('Got the wrong value of `bn_mode`.')
def ResNet50_metabin(pretrained=False,
use_ssld=False,
bias_lr_factor=1.0,
gate_lr_factor=1.0,
**kwargs):
"""
ResNet50 which replaces all `bn` layer with MetaBIN
reference: https://arxiv.org/abs/2011.14670
"""
def __init__(self, num_features):
super().__init__()
self.batch_norm = MetaBN(
num_features=num_features, use_global_stats=True)
self.instance_norm = nn.InstanceNorm2D(num_features=num_features)
self.gate = BINGate(num_features=num_features)
self.opt = defaultdict()
def forward(self, inputs):
out_bn = self.batch_norm(inputs, self.opt)
out_in = self.instance_norm(inputs)
gate = self.gate(self.opt)
gate = gate.unsqueeze([0, -1, -1])
out = out_bn * gate + out_in * (1 - gate)
return out
def reset_opt(self):
self.opt = defaultdict()
def setup_opt(self, opt):
"""
enable_inside_update: enable inside updating for `gate` in MetaBIN
lr_gate: learning rate of `gate` during meta-train phase
bn_mode: control the running stats & updating of BN
"""
self.check_opt(opt)
self.opt = copy.deepcopy(opt)
@classmethod
def check_opt(cls, opt):
assert isinstance(opt, dict), \
TypeError('Got the wrong type of `opt`. Please use `dict` type.')
if opt.get('enable_inside_update', False) and 'lr_gate' not in opt:
raise RuntimeError('Missing `lr_gate` in opt.')
assert isinstance(opt.get('lr_gate', 1.0), float), \
TypeError('Got the wrong type of `lr_gate`. Please use `float` type.')
assert isinstance(opt.get('enable_inside_update', True), bool), \
TypeError('Got the wrong type of `enable_inside_update`. Please use `bool` type.')
assert opt.get('bn_mode', "general") in ["general", "hold", "eval"], \
TypeError('Got the wrong value of `bn_mode`.')
def bn2metabin(bn, pattern):
metabin = MetaBIN(bn.weight.shape[0])
......
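For reference, the batch-instance gating that MetaBIN computes can be reproduced with stock Paddle layers. A minimal, self-contained sketch (the gate value 0.7 and all shapes are arbitrary illustrations; in the reviewed layout the real classes above are nested inside ResNet50_metabin):

import paddle
import paddle.nn as nn

# Sketch of MetaBIN's core arithmetic: out = gate * BN(x) + (1 - gate) * IN(x),
# where gate is a per-channel parameter kept in [0, 1].
num_features = 8
bn = nn.BatchNorm2D(num_features)
inorm = nn.InstanceNorm2D(num_features)
gate = paddle.full([num_features], 0.7)   # 1.0 -> pure BN, 0.0 -> pure IN
gate = gate.unsqueeze([0, -1, -1])        # broadcast to [1, C, 1, 1]

x = paddle.randn([4, num_features, 16, 16])
out = bn(x) * gate + inorm(x) * (1 - gate)
print(out.shape)                          # [4, 8, 16, 16]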
@@ -20,6 +20,10 @@ Global:
save_inference_dir: "./inference"
train_mode: 'metabin'
AMP:
scale_loss: 65536
use_dynamic_loss_scaling: True
# model architecture
Arch:
name: "RecModel"
@@ -33,10 +37,6 @@ Arch:
Neck:
name: BNNeck
num_features: &feat_dim 2048
weight_attr:
initializer:
name: Constant
value: 1.0
Head:
name: "FC"
embedding_size: *feat_dim
@@ -271,10 +271,6 @@ Optimizer:
by_epoch: False
last_epoch: 0
AMP:
scale_loss: 65536
use_dynamic_loss_scaling: True
Metric:
Eval:
- Recallk:
......
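The AMP block moved under Global corresponds to dynamic loss scaling with an initial scale of 65536. A hedged sketch of the equivalent GradScaler setup (the commented training-loop calls mirror the backward() helper later in this diff):

import paddle

scaler = paddle.amp.GradScaler(
    init_loss_scaling=65536.0,        # scale_loss: 65536
    use_dynamic_loss_scaling=True)    # use_dynamic_loss_scaling: True
# Inside a training step:
#   with paddle.amp.auto_cast():
#       loss = compute_loss(...)      # placeholder for the real loss
#   scaled = scaler.scale(loss)
#   scaled.backward()
#   scaler.minimize(optimizer, scaled)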
@@ -27,9 +27,9 @@ class DomainShuffleSampler(Sampler):
"""
def __init__(self,
dataset: str,
batch_size: int,
num_instances: int,
dataset,
batch_size,
num_instances,
camera_to_domain=True):
self.dataset = dataset
self.batch_size = batch_size
@@ -40,8 +40,12 @@ class DomainShuffleSampler(Sampler):
self.pid_domain = defaultdict(list)
self.pid_index = defaultdict(list)
# data_source: [(img_path, pid, camera, domain), ...] (camera_to_domain = True)
data_source = zip(dataset.images, dataset.labels, dataset.cameras,
dataset.cameras)
if camera_to_domain:
data_source = zip(dataset.images, dataset.labels, dataset.cameras,
dataset.cameras)
else:
data_source = zip(dataset.images, dataset.labels, dataset.cameras,
dataset.domains)
for index, info in enumerate(data_source):
domainid = info[3]
if camera_to_domain:
......
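A toy illustration of the new branch: the fourth field of each record becomes the camera id when camera_to_domain is true, and the explicit domain id otherwise (values below are made up):

images = ["a.jpg", "b.jpg"]
labels = [0, 1]
cameras = [2, 3]
domains = [0, 0]

camera_to_domain = True
data_source = zip(images, labels, cameras,
                  cameras if camera_to_domain else domains)
print(list(data_source))  # [('a.jpg', 0, 2, 2), ('b.jpg', 1, 3, 3)]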
@@ -204,7 +204,7 @@ class MSMT17(Dataset):
return len(set(self.labels))
class DukeMTMC(Dataset):
class DukeMTMC(Market1501):
"""
DukeMTMC-reID.
@@ -221,28 +221,6 @@ class DukeMTMC(Dataset):
"""
_dataset_dir = 'dukemtmc/DukeMTMC-reID'
def __init__(self,
image_root,
cls_label_path,
transform_ops=None,
backend="cv2"):
self._img_root = image_root
self._cls_path = cls_label_path # the sub folder in the dataset
self._dataset_dir = osp.join(image_root, self._dataset_dir,
self._cls_path)
self._check_before_run()
if transform_ops:
self._transform_ops = create_operators(transform_ops)
self.backend = backend
self._dtype = paddle.get_default_dtype()
self._load_anno(relabel=True if 'train' in self._cls_path else False)
def _check_before_run(self):
"""Check if the file is available before going deeper"""
if not osp.exists(self._dataset_dir):
raise RuntimeError("'{}' is not available".format(
self._dataset_dir))
def _load_anno(self, relabel=False):
img_paths = glob.glob(osp.join(self._dataset_dir, '*.jpg'))
pattern = re.compile(r'([-\d]+)_c(\d+)')
@@ -270,29 +248,6 @@ class DukeMTMC(Dataset):
self.num_pids, self.num_imgs, self.num_cams = get_imagedata_info(
self.images, self.labels, self.cameras, subfolder=self._cls_path)
def __getitem__(self, idx):
try:
img = Image.open(self.images[idx]).convert('RGB')
if self.backend == "cv2":
img = np.array(img, dtype="float32").astype(np.uint8)
if self._transform_ops:
img = transform(img, self._transform_ops)
if self.backend == "cv2":
img = img.transpose((2, 0, 1))
return (img, self.labels[idx], self.cameras[idx])
except Exception as ex:
logger.error("Exception occured when parse line: {} with msg: {}".
format(self.images[idx], ex))
rnd_idx = np.random.randint(self.__len__())
return self.__getitem__(rnd_idx)
def __len__(self):
return len(self.images)
@property
def class_num(self):
return len(set(self.labels))
def get_imagedata_info(data, labels, cameras, subfolder='train'):
pids, cams = [], []
......
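The net effect of this refactor is that DukeMTMC keeps only what differs from Market1501. A condensed sketch of the resulting class (assuming Market1501 consumes _dataset_dir the same way):

# DukeMTMC now reuses Market1501's __init__/__getitem__/__len__ and only
# points at a different dataset sub-directory.
class DukeMTMC(Market1501):
    _dataset_dir = 'dukemtmc/DukeMTMC-reID'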
@@ -24,7 +24,6 @@ from collections import defaultdict
from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_name
from ppcls.utils import profiler
from ppcls.data import build_dataloader
from ppcls.arch.backbone.variant_models.resnet_variant import MetaBIN, BINGate
from ppcls.loss import build_loss
@@ -74,7 +73,7 @@ def train_epoch_metabin(engine, epoch_id, print_batch_step):
engine.global_step += 1
if engine.global_step == 1: # update model (without gate) to warmup
if engine.global_step == 1: # update model (except gate) to warmup
for i in range(engine.config["Global"]["warmup_iter"] - 1):
out, basic_loss_dict = basic_update(engine, train_batch)
loss_dict = basic_loss_dict
@@ -143,14 +142,14 @@ def setup_opt(engine, stage):
opt["bn_mode"] = "hold"
opt["enable_inside_update"] = True
opt["lr_gate"] = norm_lr * cyclic_lr
for layer in engine.model.sublayers():
if isinstance(layer, MetaBIN):
for name, layer in engine.model.backbone.named_sublayers():
if "bn" == name.split('.')[-1]:
layer.setup_opt(opt)
def reset_opt(model):
for layer in model.sublayers():
if isinstance(layer, MetaBIN):
for name, layer in model.backbone.named_sublayers():
if "bn" == name.split('.')[-1]:
layer.reset_opt()
@@ -176,7 +175,6 @@ def get_meta_data(meta_dataloader_iter, num_domain):
mtrain_batch = None
raise RuntimeError
else:
mtrain_batch = dict()
mtrain_batch = [batch[i][is_mtrain_domain] for i in range(len(batch))]
# mtest_batch
@@ -185,7 +183,6 @@ def get_meta_data(meta_dataloader_iter, num_domain):
mtest_batch = None
raise RuntimeError
else:
mtest_batch = dict()
mtest_batch = [batch[i][is_mtest_domains] for i in range(len(batch))]
return mtrain_batch, mtest_batch
@@ -206,8 +203,8 @@ def backward(engine, loss, optimizer):
scaled = engine.scaler.scale(loss)
scaled.backward()
engine.scaler.minimize(optimizer, scaled)
for layer in engine.model.sublayers():
if isinstance(layer, BINGate):
for name, layer in engine.model.backbone.named_sublayers():
if "gate" == name.split('.')[-1]:
layer.clip_gate()
......
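The trainer now matches layers by attribute name instead of importing MetaBIN and BINGate from the backbone. A minimal sketch of the pattern (Block is a made-up layer for illustration):

import paddle.nn as nn

# Name-based sublayer matching, as used by setup_opt/reset_opt/backward above.
class Block(nn.Layer):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2D(3, 8, 3)
        self.bn = nn.BatchNorm2D(8)   # stands in for a MetaBIN layer

model = Block()
for name, layer in model.named_sublayers():
    if name.split('.')[-1] == "bn":
        print(name, type(layer).__name__)   # prints: bn BatchNorm2D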
@@ -22,7 +22,6 @@ from paddle.nn import functional as F
from .dist_loss import cosine_similarity
from .celoss import CELoss
from .triplet import TripletLoss
def euclidean_dist(x, y):
@@ -41,19 +40,29 @@ def hard_example_mining(dist_mat, is_pos, is_neg):
is_pos: positive index with shape [N, M]
is_neg: negative index with shape [N, M]
Returns:
dist_ap: distance(anchor, positive); shape [N]
dist_an: distance(anchor, negative); shape [N]
dist_ap: distance(anchor, positive); shape [N, 1]
dist_an: distance(anchor, negative); shape [N, 1]
"""
inf = float("inf")
def _masked_max(tensor, mask, axis):
masked = paddle.multiply(tensor, mask.astype(tensor.dtype))
neg_inf = paddle.zeros_like(tensor)
neg_inf.stop_gradient = True
neg_inf[paddle.logical_not(mask)] = -inf
return paddle.max(masked + neg_inf, axis=axis, keepdim=True)
def _masked_min(tensor, mask, axis):
masked = paddle.multiply(tensor, mask.astype(tensor.dtype))
pos_inf = paddle.zeros_like(tensor)
pos_inf.stop_gradient = True
pos_inf[paddle.logical_not(mask)] = inf
return paddle.min(masked + pos_inf, axis=axis, keepdim=True)
assert len(dist_mat.shape) == 2
dist_ap = list()
for i in range(dist_mat.shape[0]):
dist_ap.append(paddle.max(dist_mat[i][is_pos[i]]))
dist_ap = paddle.stack(dist_ap)
dist_an = list()
for i in range(dist_mat.shape[0]):
dist_an.append(paddle.min(dist_mat[i][is_neg[i]]))
dist_an = paddle.stack(dist_an)
dist_ap = _masked_max(dist_mat, is_pos, axis=1)
dist_an = _masked_min(dist_mat, is_neg, axis=1)
return dist_ap, dist_an
......
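A hedged usage sketch of the vectorized mining above, assuming hard_example_mining is in scope: a toy 4x4 distance matrix with two identities (is_pos here includes the anchor itself, which is harmless for the max):

import paddle

dist_mat = paddle.to_tensor([[0.0, 0.2, 0.9, 0.8],
                             [0.2, 0.0, 0.7, 0.6],
                             [0.9, 0.7, 0.0, 0.1],
                             [0.8, 0.6, 0.1, 0.0]])
labels = paddle.to_tensor([0, 0, 1, 1])
is_pos = labels.unsqueeze(0) == labels.unsqueeze(1)   # same identity
is_neg = labels.unsqueeze(0) != labels.unsqueeze(1)   # different identity

dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)  # [N, 1] each
print(dist_ap.flatten().tolist())   # hardest positives: [0.2, 0.2, 0.1, 0.1]
print(dist_an.flatten().tolist())   # hardest negatives: [0.8, 0.6, 0.7, 0.6]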
@@ -257,7 +257,6 @@ class Cyclic(LRBase):
"""Cyclic learning rate decay
Args:
epochs (int): total epoch(s)
step_each_epoch (int): number of iterations within an epoch
base_learning_rate (float): Initial learning rate, which is the lower boundary in the cycle. The paper recommends
......
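For context, a hedged sketch of a cyclic schedule using Paddle's built-in CyclicLR, which the Cyclic wrapper above presumably delegates to (parameter values are illustrative; the wrapper's exact signature may differ):

import paddle

sched = paddle.optimizer.lr.CyclicLR(
    base_learning_rate=1e-3,   # lower boundary of the cycle
    max_learning_rate=1e-2,    # upper boundary of the cycle
    step_size_up=100)          # iterations to climb from base to max
for _ in range(5):
    sched.step()
print(sched.get_lr())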