提交 fa12cf0b 编写于 作者: T tink2123

polish srn anno

上级 234bb38c
...@@ -28,6 +28,13 @@ gradient_clip = 10 ...@@ -28,6 +28,13 @@ gradient_clip = 10
class SRNPredict(object): class SRNPredict(object):
"""
SRN:
see arxiv: https://arxiv.org/abs/2003.12294
args:
params(dict): the hyperparameters used to build the network
"""
def __init__(self, params): def __init__(self, params):
super(SRNPredict, self).__init__() super(SRNPredict, self).__init__()
self.char_num = params['char_num'] self.char_num = params['char_num']
...@@ -39,7 +46,15 @@ class SRNPredict(object): ...@@ -39,7 +46,15 @@ class SRNPredict(object):
self.hidden_dims = params['hidden_dims'] self.hidden_dims = params['hidden_dims']
def pvam(self, inputs, others): def pvam(self, inputs, others):
"""
Parallel Visual Attention Module
args:
inputs(variable): Feature map extracted from backbone network
others(list): Other location information variables
return: pvam_features
"""
b, c, h, w = inputs.shape b, c, h, w = inputs.shape
conv_features = fluid.layers.reshape(x=inputs, shape=[-1, c, h * w]) conv_features = fluid.layers.reshape(x=inputs, shape=[-1, c, h * w])
conv_features = fluid.layers.transpose(x=conv_features, perm=[0, 2, 1]) conv_features = fluid.layers.transpose(x=conv_features, perm=[0, 2, 1])
...@@ -98,6 +113,15 @@ class SRNPredict(object): ...@@ -98,6 +113,15 @@ class SRNPredict(object):
return pvam_features return pvam_features
def gsrm(self, pvam_features, others): def gsrm(self, pvam_features, others):
"""
Global Semantic Reasoning Module
args:
pvam_features(variable): Feature map extracted from pvam
others(list): Other location information variables
return: gsrm_features, word_out, gsrm_out
"""
#===== GSRM Visual-to-semantic embedding block ===== #===== GSRM Visual-to-semantic embedding block =====
b, t, c = pvam_features.shape b, t, c = pvam_features.shape
...@@ -190,7 +214,15 @@ class SRNPredict(object): ...@@ -190,7 +214,15 @@ class SRNPredict(object):
return gsrm_features, word_out, gsrm_out return gsrm_features, word_out, gsrm_out
def vsfd(self, pvam_features, gsrm_features): def vsfd(self, pvam_features, gsrm_features):
"""
Visual-Semantic Fusion Decoder Module
args:
pvam_features(variable): Feature map extracted from pvam
gsrm_features(variable): Feature map extracted from gsrm
return: fc_out
"""
#===== Visual-Semantic Fusion Decoder Module ===== #===== Visual-Semantic Fusion Decoder Module =====
b, t, c1 = pvam_features.shape b, t, c1 = pvam_features.shape
b, t, c2 = gsrm_features.shape b, t, c2 = gsrm_features.shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册