Commit 24788265 authored by W wuzewu

Fix module compatibility issues

Parent c4f19c8e
......@@ -48,7 +48,11 @@ sys.modules['paddlehub.common.logger'] = log
sys.modules['paddlehub.common.paddle_helper'] = paddle_utils
sys.modules['paddlehub.common.utils'] = utils
sys.modules['paddlehub.reader'] = task
sys.modules['paddlehub.reader.batching'] = task.batch
AdamWeightDecayStrategy = lambda: 0
ULMFiTStrategy = lambda params_layer=0: 0
common = EasyDict(paddle_helper=paddle_utils)
dataset = EasyDict(Couplet=couplet.Couplet)
AdamWeightDecayStrategy = lambda: 0
finetune = EasyDict(strategy=EasyDict(ULMFiTStrategy=ULMFiTStrategy))
logger = EasyDict(logger=log.logger)
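The hunk above registers legacy dotted paths in `sys.modules` and stubs out old strategy classes (`AdamWeightDecayStrategy`, `ULMFiTStrategy`) so that code written against the PaddleHub 1.x layout keeps importing. The snippet below is a minimal, self-contained sketch of that `sys.modules` aliasing pattern, not the project's actual code; the package names `oldpkg` and `newpkg` are illustrative.

```python
import sys
import types

# Stand-in for the relocated implementation (plays the role of `log` above).
new_log = types.ModuleType('newpkg.utils.log')
new_log.logger = lambda msg: print('[logger]', msg)  # attribute legacy callers expect

# Alias the legacy dotted path to the new module object, mirroring
# `sys.modules['paddlehub.common.logger'] = log` in the hunk above.
sys.modules['oldpkg.common.logger'] = new_log

# A legacy-style import now resolves to the aliased module, even though no
# `oldpkg` package exists on disk.
from oldpkg.common.logger import logger
logger('compatibility alias works')
```

The lambda stubs serve the same purpose: old code can still reference `AdamWeightDecayStrategy` or `ULMFiTStrategy`, but the returned objects are inert placeholders.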
......@@ -118,8 +118,8 @@ class ModuleV1(object):
                op._set_attr('op_callstack', [''])

    @paddle_utils.run_in_static_mode
    def context(self, signature: str = None, for_test: bool = False,
                trainable: bool = True) -> Tuple[dict, dict, paddle.static.Program]:
    def context(self, signature: str = None, for_test: bool = False, trainable: bool = True,
                max_seq_len: int = 128) -> Tuple[dict, dict, paddle.static.Program]:
        '''Get module context information, including graph structure and graph input and output variables.'''
        program = self.program.clone(for_test=for_test)
        paddle_utils.remove_feed_fetch_op(program)
......@@ -141,8 +141,27 @@ class ModuleV1(object):
        for param in program.all_parameters():
            param.trainable = trainable

        # The bert series model saved by ModuleV1 sets max_seq_len to 512 by default. We need to adjust max_seq_len
        # according to the parameters in actual use.
        if 'bert' in self.name or self.name.startswith('ernie'):
            self._update_bert_max_seq_len(program, feed_dict, max_seq_len)

        return feed_dict, fetch_dict, program

    def _update_bert_max_seq_len(self, program: paddle.static.Program, feed_dict: dict, max_seq_len: int = 128):
        MAX_SEQ_LENGTH = 512
        if max_seq_len > MAX_SEQ_LENGTH or max_seq_len <= 0:
            raise ValueError("max_seq_len({}) should be in the range of [1, {}]".format(max_seq_len, MAX_SEQ_LENGTH))
        log.logger.info("Set maximum sequence length of input tensor to {}".format(max_seq_len))
        if self.name.startswith("ernie_v2"):
            feed_list = ["input_ids", "position_ids", "segment_ids", "input_mask", "task_ids"]
        else:
            feed_list = ["input_ids", "position_ids", "segment_ids", "input_mask"]
        for tensor_name in feed_list:
            seq_tensor_shape = [-1, max_seq_len, 1]
            log.logger.info("The shape of input tensor[{}] set to {}".format(tensor_name, seq_tensor_shape))
            program.global_block().var(feed_dict[tensor_name].name).desc.set_shape(seq_tensor_shape)

    @paddle_utils.run_in_static_mode
    def __call__(self, sign_name: str, data: dict, use_gpu: bool = False, batch_size: int = 1, **kwargs):
        '''Call the specified signature function for prediction.'''
......
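For reference, a hedged usage sketch of the new `max_seq_len` argument that this commit adds to `ModuleV1.context`. The module name `'ernie'` and the chosen length are illustrative assumptions; per the diff, `max_seq_len` must fall in [1, 512] and each BERT-style feed tensor is reshaped to `[-1, max_seq_len, 1]`.

```python
import paddlehub as hub

# Load an ernie/bert-series module saved in the ModuleV1 format
# (module name is an example; the module must be installed locally).
module = hub.Module(name='ernie')

# Request the graph with a 64-token cap instead of the 512 the module was
# saved with; values outside [1, 512] raise ValueError per the diff above.
inputs, outputs, program = module.context(trainable=True, max_seq_len=64)

# Read the adjusted static shape back the same way the diff sets it.
input_ids = program.global_block().var(inputs['input_ids'].name)
print(input_ids.shape)  # expected to reflect [-1, 64, 1]
```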