提交 aed02a21 编写于 作者: U user3984 提交者: littletomatodonkey

update pefd

上级 7ee8471d
...@@ -24,7 +24,7 @@ class Regressor(nn.Layer): ...@@ -24,7 +24,7 @@ class Regressor(nn.Layer):
def __init__(self, dim_in=1024, dim_out=1024): def __init__(self, dim_in=1024, dim_out=1024):
super(Regressor, self).__init__() super(Regressor, self).__init__()
self.conv = nn.Conv2D(dim_in, dim_out, 1) self.conv = nn.Linear(dim_in, dim_out)
def forward(self, x): def forward(self, x):
x = self.conv(x) x = self.conv(x)
...@@ -38,29 +38,38 @@ class PEFDLoss(nn.Layer): ...@@ -38,29 +38,38 @@ class PEFDLoss(nn.Layer):
Code reference: https://github.com/chenyd7/PEFD Code reference: https://github.com/chenyd7/PEFD
""" """
def __init__(self, student_channel, teacher_channel, num_projectors=3): def __init__(self,
student_channel,
teacher_channel,
num_projectors=3,
mode="flatten"):
super().__init__() super().__init__()
if num_projectors <= 0: if num_projectors <= 0:
raise ValueError("Number of projectors must be greater than 0.") raise ValueError("Number of projectors must be greater than 0.")
if mode not in ["flatten", "gap"]:
raise ValueError("Mode must be \"flatten\" or \"gap\".")
self.mode = mode
self.projectors = nn.LayerList() self.projectors = nn.LayerList()
for _ in range(num_projectors): for _ in range(num_projectors):
self.projectors.append(Regressor(student_channel, teacher_channel)) self.projectors.append(Regressor(student_channel, teacher_channel))
def forward(self, student_feature, teacher_feature): def forward(self, student_feature, teacher_feature):
if student_feature.shape[2:] != teacher_feature.shape[2:]: if self.mode == "gap":
raise ValueError( student_feature = F.adaptive_avg_pool2d(student_feature, (1, 1))
"Student feature must have the same H and W as teacher feature." teacher_feature = F.adaptive_avg_pool2d(teacher_feature, (1, 1))
)
student_feature = student_feature.flatten(1)
f_t = teacher_feature.flatten(1)
q = len(self.projectors) q = len(self.projectors)
f_s = 0.0 f_s = 0.0
for i in range(q): for i in range(q):
f_s += self.projectors[i](student_feature) f_s += self.projectors[i](student_feature)
f_s = (f_s / q).flatten(1) f_s = f_s / q
f_t = teacher_feature.flatten(1)
# inner product (normalize first and inner product) # inner product (normalize first and inner product)
normft = f_t.pow(2).sum(1, keepdim=True).pow(1. / 2) normft = f_t.pow(2).sum(1, keepdim=True).pow(1. / 2)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册