提交 c0492e02 编写于 作者:MissPenguin

refine

上级 de6bc445
...@@ -69,7 +69,7 @@ class BaseModel(nn.Layer): ...@@ -69,7 +69,7 @@ class BaseModel(nn.Layer):
self.return_all_feats = config.get("return_all_feats", False) self.return_all_feats = config.get("return_all_feats", False)
def forward(self, x, data=None, mode='Train'): def forward(self, x, data=None):
y = dict() y = dict()
if self.use_transform: if self.use_transform:
x = self.transform(x) x = self.transform(x)
...@@ -78,13 +78,7 @@ class BaseModel(nn.Layer): ...@@ -78,13 +78,7 @@ class BaseModel(nn.Layer):
if self.use_neck: if self.use_neck:
x = self.neck(x) x = self.neck(x)
y["neck_out"] = x y["neck_out"] = x
if data is None: x = self.head(x, targets=data)
x = self.head(x)
else:
if mode == 'Eval' or mode == 'Test':
x = self.head(x, targets=data, mode=mode)
else:
x = self.head(x, targets=data)
y["head_out"] = x y["head_out"] = x
if self.return_all_feats: if self.return_all_feats:
return y return y
......
...@@ -43,7 +43,7 @@ class ClsHead(nn.Layer): ...@@ -43,7 +43,7 @@ class ClsHead(nn.Layer):
initializer=nn.initializer.Uniform(-stdv, stdv)), initializer=nn.initializer.Uniform(-stdv, stdv)),
bias_attr=ParamAttr(name="fc_0.b_0"), ) bias_attr=ParamAttr(name="fc_0.b_0"), )
def forward(self, x): def forward(self, x, targets=None):
x = self.pool(x) x = self.pool(x)
x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]]) x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
x = self.fc(x) x = self.fc(x)
......
...@@ -106,7 +106,7 @@ class DBHead(nn.Layer): ...@@ -106,7 +106,7 @@ class DBHead(nn.Layer):
def step_function(self, x, y): def step_function(self, x, y):
return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y))) return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
def forward(self, x): def forward(self, x, targets=None):
shrink_maps = self.binarize(x) shrink_maps = self.binarize(x)
if not self.training: if not self.training:
return {'maps': shrink_maps} return {'maps': shrink_maps}
......
...@@ -109,7 +109,7 @@ class EASTHead(nn.Layer): ...@@ -109,7 +109,7 @@ class EASTHead(nn.Layer):
act=None, act=None,
name="f_geo") name="f_geo")
def forward(self, x): def forward(self, x, targets=None):
f_det = self.det_conv1(x) f_det = self.det_conv1(x)
f_det = self.det_conv2(f_det) f_det = self.det_conv2(f_det)
f_score = self.score_conv(f_det) f_score = self.score_conv(f_det)
......
...@@ -116,7 +116,7 @@ class SASTHead(nn.Layer): ...@@ -116,7 +116,7 @@ class SASTHead(nn.Layer):
self.head1 = SAST_Header1(in_channels) self.head1 = SAST_Header1(in_channels)
self.head2 = SAST_Header2(in_channels) self.head2 = SAST_Header2(in_channels)
def forward(self, x): def forward(self, x, targets=None):
f_score, f_border = self.head1(x) f_score, f_border = self.head1(x)
f_tvo, f_tco = self.head2(x) f_tvo, f_tco = self.head2(x)
......
...@@ -220,7 +220,7 @@ class PGHead(nn.Layer): ...@@ -220,7 +220,7 @@ class PGHead(nn.Layer):
weight_attr=ParamAttr(name="conv_f_direc{}".format(4)), weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
bias_attr=False) bias_attr=False)
def forward(self, x): def forward(self, x, targets=None):
f_score = self.conv_f_score1(x) f_score = self.conv_f_score1(x)
f_score = self.conv_f_score2(f_score) f_score = self.conv_f_score2(f_score)
f_score = self.conv_f_score3(f_score) f_score = self.conv_f_score3(f_score)
......
...@@ -44,7 +44,7 @@ class CTCHead(nn.Layer): ...@@ -44,7 +44,7 @@ class CTCHead(nn.Layer):
bias_attr=bias_attr) bias_attr=bias_attr)
self.out_channels = out_channels self.out_channels = out_channels
def forward(self, x, labels=None): def forward(self, x, targets=None):
predicts = self.fc(x) predicts = self.fc(x)
if not self.training: if not self.training:
predicts = F.softmax(predicts, axis=2) predicts = F.softmax(predicts, axis=2)
......
...@@ -53,7 +53,7 @@ class TableAttentionHead(nn.Layer): ...@@ -53,7 +53,7 @@ class TableAttentionHead(nn.Layer):
input_ont_hot = F.one_hot(input_char, onehot_dim) input_ont_hot = F.one_hot(input_char, onehot_dim)
return input_ont_hot return input_ont_hot
def forward(self, inputs, targets=None, mode='Train'): def forward(self, inputs, targets=None):
# if and else branch are both needed when you want to assign a variable # if and else branch are both needed when you want to assign a variable
# if you modify the var in just one branch, then the modification will not work. # if you modify the var in just one branch, then the modification will not work.
fea = inputs[-1] fea = inputs[-1]
...@@ -67,7 +67,7 @@ class TableAttentionHead(nn.Layer): ...@@ -67,7 +67,7 @@ class TableAttentionHead(nn.Layer):
hidden = paddle.zeros((batch_size, self.hidden_size)) hidden = paddle.zeros((batch_size, self.hidden_size))
output_hiddens = [] output_hiddens = []
if mode == 'Train' and targets is not None: if self.training and targets is not None:
structure = targets[0] structure = targets[0]
for i in range(self.max_elem_length+1): for i in range(self.max_elem_length+1):
elem_onehots = self._char_to_onehot( elem_onehots = self._char_to_onehot(
......
...@@ -81,7 +81,7 @@ def main(config, device, logger, vdl_writer): ...@@ -81,7 +81,7 @@ def main(config, device, logger, vdl_writer):
batch = transform(data, ops) batch = transform(data, ops)
images = np.expand_dims(batch[0], axis=0) images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images) images = paddle.to_tensor(images)
preds = model(images, data=None, mode='Test') preds = model(images)
post_result = post_process_class(preds) post_result = post_process_class(preds)
res_html_code = post_result['res_html_code'] res_html_code = post_result['res_html_code']
res_loc = post_result['res_loc'] res_loc = post_result['res_loc']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册