提交 9740ead0 编写于 作者: W wuzewu

The context interface supports getting all signatures's input and output variables

上级 a65abc82
...@@ -451,7 +451,7 @@ class Module(object): ...@@ -451,7 +451,7 @@ class Module(object):
raise ValueError("This Module is not callable!") raise ValueError("This Module is not callable!")
def context(self, def context(self,
sign_name, sign_name=None,
for_test=False, for_test=False,
trainable=False, trainable=False,
regularizer=None, regularizer=None,
...@@ -463,10 +463,35 @@ class Module(object): ...@@ -463,10 +463,35 @@ class Module(object):
available for BERT/ERNIE module available for BERT/ERNIE module
""" """
if sign_name not in self.signatures: if sign_name:
raise KeyError( if sign_name not in self.signatures:
"Module did not have a signature with name %s" % sign_name) raise KeyError(
signature = self.signatures[sign_name] "Module did not have a signature with name %s" % sign_name)
signature = self.signatures[sign_name]
else:
inputs = [
input for signature in self.signatures.values()
for input in signature.inputs
]
outputs = [
output for signature in self.signatures.values()
for output in signature.outputs
]
feed_names = [
feed_name for signature in self.signatures.values()
for feed_name in signature.feed_names
]
fetch_names = [
fetch_name for signature in self.signatures.values()
for fetch_name in signature.fetch_names
]
signature = create_signature(
name="hub_temp_signature",
inputs=inputs,
outputs=outputs,
feed_names=feed_names,
fetch_names=fetch_names,
for_predict=False)
program = self.program.clone(for_test=for_test) program = self.program.clone(for_test=for_test)
paddle_helper.remove_feed_fetch_op(program) paddle_helper.remove_feed_fetch_op(program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册