提交 9740ead0 编写于 作者: W wuzewu

The context interface supports getting all signatures's input and output variables

上级 a65abc82
......@@ -451,7 +451,7 @@ class Module(object):
raise ValueError("This Module is not callable!")
def context(self,
sign_name,
sign_name=None,
for_test=False,
trainable=False,
regularizer=None,
......@@ -463,10 +463,35 @@ class Module(object):
available for BERT/ERNIE module
"""
if sign_name not in self.signatures:
raise KeyError(
"Module did not have a signature with name %s" % sign_name)
signature = self.signatures[sign_name]
if sign_name:
if sign_name not in self.signatures:
raise KeyError(
"Module did not have a signature with name %s" % sign_name)
signature = self.signatures[sign_name]
else:
inputs = [
input for signature in self.signatures.values()
for input in signature.inputs
]
outputs = [
output for signature in self.signatures.values()
for output in signature.outputs
]
feed_names = [
feed_name for signature in self.signatures.values()
for feed_name in signature.feed_names
]
fetch_names = [
fetch_name for signature in self.signatures.values()
for fetch_name in signature.fetch_names
]
signature = create_signature(
name="hub_temp_signature",
inputs=inputs,
outputs=outputs,
feed_names=feed_names,
fetch_names=fetch_names,
for_predict=False)
program = self.program.clone(for_test=for_test)
paddle_helper.remove_feed_fetch_op(program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册