提交 6419c7bd 编写于 作者: 文幕地方's avatar 文幕地方

add stop_gradient to create tensor

上级 f6698a32
...@@ -216,6 +216,8 @@ class SLAHead(nn.Layer): ...@@ -216,6 +216,8 @@ class SLAHead(nn.Layer):
hidden = paddle.zeros((batch_size, self.hidden_size)) hidden = paddle.zeros((batch_size, self.hidden_size))
structure_preds = paddle.zeros((batch_size, self.max_text_length + 1, self.num_embeddings)) structure_preds = paddle.zeros((batch_size, self.max_text_length + 1, self.num_embeddings))
loc_preds = paddle.zeros((batch_size, self.max_text_length + 1, self.loc_reg_num)) loc_preds = paddle.zeros((batch_size, self.max_text_length + 1, self.loc_reg_num))
structure_preds.stop_gradient = True
loc_preds.stop_gradient = True
if self.training and targets is not None: if self.training and targets is not None:
structure = targets[0] structure = targets[0]
for i in range(self.max_text_length + 1): for i in range(self.max_text_length + 1):
...@@ -223,6 +225,7 @@ class SLAHead(nn.Layer): ...@@ -223,6 +225,7 @@ class SLAHead(nn.Layer):
fea, hidden) fea, hidden)
structure_preds[:, i, :] = structure_step structure_preds[:, i, :] = structure_step
loc_preds[:, i, :] = loc_step loc_preds[:, i, :] = loc_step
else:
pre_chars = paddle.zeros(shape=[batch_size], dtype="int32") pre_chars = paddle.zeros(shape=[batch_size], dtype="int32")
max_text_length = paddle.to_tensor(self.max_text_length) max_text_length = paddle.to_tensor(self.max_text_length)
# for export # for export
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册