提交 67c508e4 编写于 作者: Z Zeyu Chen

fix senta to senta_bilstm

上级 fcab4a90
python ../../paddlehub/commands/hub.py run senta --input_file test/test.txt python ../../paddlehub/commands/hub.py run senta_bilstm --input_file test/test.txt
...@@ -21,7 +21,7 @@ args = parser.parse_args() ...@@ -21,7 +21,7 @@ args = parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
# loading Paddlehub senta pretrained model # loading Paddlehub senta pretrained model
module = hub.Module(name="senta") module = hub.Module(name="senta_bilstm")
inputs, outputs, program = module.context(trainable=True) inputs, outputs, program = module.context(trainable=True)
# Sentence classification dataset reader # Sentence classification dataset reader
...@@ -32,13 +32,11 @@ if __name__ == '__main__': ...@@ -32,13 +32,11 @@ if __name__ == '__main__':
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
with fluid.program_guard(program): with fluid.program_guard(program):
# Use "sequence_output" for classification tasks on an entire sentence. sent_feature = outputs["sentence_feature"]
# Use "sequence_outputs" for token-level output.
sequence_output = outputs["sequence_output"]
# Define a classfication finetune task by PaddleHub's API # Define a classfication finetune task by PaddleHub's API
cls_task = hub.create_text_cls_task( cls_task = hub.create_text_cls_task(
feature=sequence_output, num_classes=dataset.num_labels) feature=sent_feature, num_classes=dataset.num_labels)
# Setup feed list for data feeder # Setup feed list for data feeder
# Must feed all the tensor of senta's module need # Must feed all the tensor of senta's module need
...@@ -69,4 +67,4 @@ if __name__ == '__main__': ...@@ -69,4 +67,4 @@ if __name__ == '__main__':
correct += 1 correct += 1
acc = 1.0 * correct / total acc = 1.0 * correct / total
print("%s\tpredict=%s" % (test_examples[index], pred_v[0][0])) print("%s\tpredict=%s" % (test_examples[index], pred_v[0][0]))
print("accuracy = %f" % acc) print("accuracy = %f" % acc)
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=5
DATASET="chnsenticorp" DATASET="chnsenticorp"
CKPT_DIR="./ckpt_${DATASET}" CKPT_DIR="./ckpt_${DATASET}"
python -u senta_finetune.py \ python -u senta_finetune.py \
--batch_size=24 \ --batch_size=24 \
--use_gpu=False \ --use_gpu=True \
--checkpoint_dir=${CKPT_DIR} \ --checkpoint_dir=${CKPT_DIR} \
--num_epoch=3 --num_epoch=3
...@@ -15,7 +15,7 @@ args = parser.parse_args() ...@@ -15,7 +15,7 @@ args = parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
# Step1: load Paddlehub senta pretrained model # Step1: load Paddlehub senta pretrained model
module = hub.Module(name="senta") module = hub.Module(name="senta_bilstm")
inputs, outputs, program = module.context(trainable=True) inputs, outputs, program = module.context(trainable=True)
# Step2: Download dataset and use TextClassificationReader to read dataset # Step2: Download dataset and use TextClassificationReader to read dataset
...@@ -24,7 +24,7 @@ if __name__ == '__main__': ...@@ -24,7 +24,7 @@ if __name__ == '__main__':
reader = hub.reader.LACClassifyReader( reader = hub.reader.LACClassifyReader(
dataset=dataset, vocab_path=module.get_vocab_path()) dataset=dataset, vocab_path=module.get_vocab_path())
sent_feature = outputs["sequence_output"] sent_feature = outputs["sentence_feature"]
# Define a classfication finetune task by PaddleHub's API # Define a classfication finetune task by PaddleHub's API
cls_task = hub.create_text_cls_task( cls_task = hub.create_text_cls_task(
...@@ -35,7 +35,7 @@ if __name__ == '__main__': ...@@ -35,7 +35,7 @@ if __name__ == '__main__':
feed_list = [inputs["words"].name, cls_task.variable('label').name] feed_list = [inputs["words"].name, cls_task.variable('label').name]
strategy = hub.finetune.strategy.AdamWeightDecayStrategy( strategy = hub.finetune.strategy.AdamWeightDecayStrategy(
learning_rate=1e-3, weight_decay=0.01, warmup_proportion=0.01) learning_rate=1e-4, weight_decay=0.01, warmup_proportion=0.05)
config = hub.RunConfig( config = hub.RunConfig(
use_cuda=args.use_gpu, use_cuda=args.use_gpu,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册