diff --git a/python/examples/bert/prepare_model.py b/python/examples/bert/prepare_model.py index 674947b9c0966b142da93f56e6f9b6ab61098a62..70902adf9268d1071c79eb27216dcc2ea9a11a49 100644 --- a/python/examples/bert/prepare_model.py +++ b/python/examples/bert/prepare_model.py @@ -17,11 +17,10 @@ import paddle.fluid as fluid import sys import paddle_serving_client.io as serving_io -#model_name = "bert_chinese_L-12_H-768_A-12" -model_name = sys.argv[1] +model_name = "bert_chinese_L-12_H-768_A-12" module = hub.Module(model_name) inputs, outputs, program = module.context( - trainable=True, max_seq_len=int(sys.argv[2])) + trainable=True, max_seq_len=int(sys.argv[1])) place = fluid.core_avx.CPUPlace() exe = fluid.Executor(place) input_ids = inputs["input_ids"] @@ -38,8 +37,8 @@ feed_var_names = [ target_vars = [pooled_output, sequence_output] serving_io.save_model( - "{}_seq{}_model".format(sys.argv[1], sys.argv[2]), - "{}_seq{}_client".format(sys.argv[1], sys.argv[2]), { + "bert_seq{}_model".format(sys.argv[1]), + "bert_seq{}_client".format(sys.argv[1]), { "input_ids": input_ids, "position_ids": position_ids, "segment_ids": segment_ids,