未验证 提交 885432bf 编写于 作者: X xiemoyuan 提交者: GitHub

Rename 'decoder' to 'decode'. (#5056)

上级 3698045c
...@@ -326,7 +326,7 @@ class Plato2InferModel(nn.Layer): ...@@ -326,7 +326,7 @@ class Plato2InferModel(nn.Layer):
enc_out, new_caches = self.plato2_encoder( enc_out, new_caches = self.plato2_encoder(
caches, token_ids, type_ids, pos_ids, generation_mask, latent_emb) caches, token_ids, type_ids, pos_ids, generation_mask, latent_emb)
pred_ids = self.decoder(inputs, new_caches) pred_ids = self.decode(inputs, new_caches)
nsp_inputs = self.gen_nsp_input(token_ids, pred_ids) nsp_inputs = self.gen_nsp_input(token_ids, pred_ids)
# [-1, 2] # [-1, 2]
...@@ -334,7 +334,7 @@ class Plato2InferModel(nn.Layer): ...@@ -334,7 +334,7 @@ class Plato2InferModel(nn.Layer):
return self.get_results(data_id, token_ids, pred_ids, probs) return self.get_results(data_id, token_ids, pred_ids, probs)
def decoder(self, inputs, caches): def decode(self, inputs, caches):
tgt_ids = inputs['tgt_ids'] tgt_ids = inputs['tgt_ids']
tgt_pos = inputs['tgt_pos'] tgt_pos = inputs['tgt_pos']
tgt_generation_mask = inputs['tgt_generation_mask'] tgt_generation_mask = inputs['tgt_generation_mask']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册