提交 87f508e9 编写于 作者: W weishengyu

dbg

上级 b662ed34
...@@ -39,6 +39,7 @@ class GalleryLayer(paddle.nn.Layer): ...@@ -39,6 +39,7 @@ class GalleryLayer(paddle.nn.Layer):
gallery_docs.append(ori_line.strip()) gallery_docs.append(ori_line.strip())
gallery_labels.append(line[1].strip()) gallery_labels.append(line[1].strip())
self.gallery_layer = paddle.nn.Linear(embedding_size, len(self.gallery_images), bias_attr=False) self.gallery_layer = paddle.nn.Linear(embedding_size, len(self.gallery_images), bias_attr=False)
self.gallery_layer.skip_quant = True
def forward(self, x, label=None): def forward(self, x, label=None):
x = paddle.nn.functional.normalize(x) x = paddle.nn.functional.normalize(x)
...@@ -63,8 +64,8 @@ class GalleryLayer(paddle.nn.Layer): ...@@ -63,8 +64,8 @@ class GalleryLayer(paddle.nn.Layer):
for j in range(batch_index): for j in range(batch_index):
feature = batch_feature[j] feature = batch_feature[j]
norm_feature = paddle.nn.functional.normalize(feature, axis=0) norm_feature = paddle.nn.functional.normalize(feature, axis=0)
gallery_feature[i - batch_index + j] = norm_feature gallery_feature[i - batch_index + j + 1] = norm_feature
self.gallery_layer.set_state_dict({"weight": gallery_feature.T}) self.gallery_layer.set_state_dict({"_layer.weight": gallery_feature.T})
def export_fuse_model(configs): def export_fuse_model(configs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册