未验证 提交 afadcf05 编写于 作者: S Steffy-zxf 提交者: GitHub

update get embedding api for tranformer module (#601)

上级 d6763ca2
...@@ -229,6 +229,9 @@ class _TransformerEmbeddingTask(hub.BaseTask): ...@@ -229,6 +229,9 @@ class _TransformerEmbeddingTask(hub.BaseTask):
self.seq_feature = seq_feature self.seq_feature = seq_feature
def _build_net(self): def _build_net(self):
# ClassifyReader will return the seqence length of an input text
self.seq_len = fluid.layers.data(
name="seq_len", shape=[1], dtype='int64', lod_level=0)
return [self.pooled_feature, self.seq_feature] return [self.pooled_feature, self.seq_feature]
def _postprocessing(self, run_states): def _postprocessing(self, run_states):
...@@ -242,6 +245,18 @@ class _TransformerEmbeddingTask(hub.BaseTask): ...@@ -242,6 +245,18 @@ class _TransformerEmbeddingTask(hub.BaseTask):
[batch_pooled_features[i], batch_seq_features[i]]) [batch_pooled_features[i], batch_seq_features[i]])
return results return results
@property
def feed_list(self):
feed_list = [varname
for varname in self._base_feed_list] + [self.seq_len.name]
return feed_list
@property
def fetch_list(self):
fetch_list = [output.name
for output in self.outputs] + [self.seq_len.name]
return fetch_list
class TransformerModule(NLPBaseModule): class TransformerModule(NLPBaseModule):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册