diff --git a/paddlehub/module/module.py b/paddlehub/module/module.py index bbb52c018bcab0b0376554a8d3cfc7914b7e001a..9b1212e7252fa9f2641600aff5148de1f9c72ea8 100644 --- a/paddlehub/module/module.py +++ b/paddlehub/module/module.py @@ -451,7 +451,7 @@ class Module(object): raise ValueError("This Module is not callable!") def context(self, - sign_name, + sign_name=None, for_test=False, trainable=False, regularizer=None, @@ -463,10 +463,35 @@ class Module(object): available for BERT/ERNIE module """ - if sign_name not in self.signatures: - raise KeyError( - "Module did not have a signature with name %s" % sign_name) - signature = self.signatures[sign_name] + if sign_name: + if sign_name not in self.signatures: + raise KeyError( + "Module did not have a signature with name %s" % sign_name) + signature = self.signatures[sign_name] + else: + inputs = [ + input for signature in self.signatures.values() + for input in signature.inputs + ] + outputs = [ + output for signature in self.signatures.values() + for output in signature.outputs + ] + feed_names = [ + feed_name for signature in self.signatures.values() + for feed_name in signature.feed_names + ] + fetch_names = [ + fetch_name for signature in self.signatures.values() + for fetch_name in signature.fetch_names + ] + signature = create_signature( + name="hub_temp_signature", + inputs=inputs, + outputs=outputs, + feed_names=feed_names, + fetch_names=fetch_names, + for_predict=False) program = self.program.clone(for_test=for_test) paddle_helper.remove_feed_fetch_op(program)