Commit a271703a authored by Steffy-zxf, committed by wuzewu

Update the senta demo (#56)

Parent 3c85dfe5
......@@ -15,8 +15,9 @@ import paddlehub as hub
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for finetuning, input should be True or False")
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether use GPU for finetuning, input should be True or False")
parser.add_argument("--use_pyreader", type=ast.literal_eval, default=False, help="Whether use pyreader to feed data.")
args = parser.parse_args()
# yapf: enable.
......@@ -30,42 +31,39 @@ if __name__ == '__main__':
reader = hub.reader.LACClassifyReader(
dataset=dataset, vocab_path=module.get_vocab_path())
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
with fluid.program_guard(program):
sent_feature = outputs["sentence_feature"]
strategy = hub.AdamWeightDecayStrategy(
weight_decay=0.01,
warmup_proportion=0.1,
learning_rate=5e-5,
lr_scheduler="linear_decay",
optimizer_name="adam")
# Define a classification finetune task by PaddleHub's API
cls_task = hub.create_text_cls_task(
feature=sent_feature, num_classes=dataset.num_labels)
config = hub.RunConfig(
use_data_parallel=False,
use_pyreader=args.use_pyreader,
use_cuda=args.use_gpu,
batch_size=1,
enable_memory_optim=False,
checkpoint_dir=args.checkpoint_dir,
strategy=strategy)
# Set up feed list for data feeder
# Must feed all the tensors that senta's module needs
feed_list = [inputs["words"].name, cls_task.variable('label').name]
sent_feature = outputs["sentence_feature"]
# classification probability tensor
probs = cls_task.variable("probs")
feed_list = [inputs["words"].name]
pred = fluid.layers.argmax(probs, axis=1)
cls_task = hub.TextClassifierTask(
data_reader=reader,
feature=sent_feature,
feed_list=feed_list,
num_classes=dataset.num_labels,
config=config)
# load best model checkpoint
fluid.io.load_persistables(exe, args.checkpoint_dir)
data = ["这家餐厅很好吃", "这部电影真的很差劲"]
inference_program = program.clone(for_test=True)
data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
test_reader = reader.data_generator(phase='test', shuffle=False)
test_examples = dataset.get_test_examples()
total = 0
correct = 0
for index, batch in enumerate(test_reader()):
pred_v = exe.run(
feed=data_feeder.feed(batch),
fetch_list=[pred.name],
program=inference_program)
total += 1
if (pred_v[0][0] == int(test_examples[index].label)):
correct += 1
acc = 1.0 * correct / total
print("%s\tpredict=%s" % (test_examples[index], pred_v[0][0]))
print("accuracy = %f" % acc)
results = cls_task.predict(data=data)
index = 0
for batch_result in results:
batch_result = np.argmax(batch_result, axis=2)[0]
for result in batch_result:
print("%s\tpredict=%s" % (data[index], result))
index += 1
export CUDA_VISIBLE_DEVICES=5
export CUDA_VISIBLE_DEVICES=0
DATASET="chnsenticorp"
CKPT_DIR="./ckpt_${DATASET}"
......
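Pulled together, the updated prediction demo in this commit reduces to roughly the sketch below. It is only a condensed rearrangement of code visible in the diff above, not the exact file contents; the checkpoint directory and the use_cuda value are placeholders standing in for the script's --checkpoint_dir and --use_gpu arguments.

# Minimal sketch of the updated prediction flow, condensed from the diff above.
import numpy as np
import paddlehub as hub

module = hub.Module(name="senta_bilstm")
inputs, outputs, program = module.context(trainable=True)

dataset = hub.dataset.ChnSentiCorp()
reader = hub.reader.LACClassifyReader(
    dataset=dataset, vocab_path=module.get_vocab_path())

strategy = hub.AdamWeightDecayStrategy(
    weight_decay=0.01,
    warmup_proportion=0.1,
    learning_rate=5e-5,
    lr_scheduler="linear_decay",
    optimizer_name="adam")

config = hub.RunConfig(
    use_data_parallel=False,
    use_cuda=False,                         # placeholder for --use_gpu
    batch_size=1,
    enable_memory_optim=False,
    checkpoint_dir="./ckpt_chnsenticorp",   # placeholder for --checkpoint_dir
    strategy=strategy)

cls_task = hub.TextClassifierTask(
    data_reader=reader,
    feature=outputs["sentence_feature"],
    feed_list=[inputs["words"].name],
    num_classes=dataset.num_labels,
    config=config)

data = ["这家餐厅很好吃", "这部电影真的很差劲"]
index = 0
# Iterate over per-batch results, mirroring the loop in the diff above
for batch_result in cls_task.predict(data=data):
    for result in np.argmax(batch_result, axis=2)[0]:
        print("%s\tpredict=%s" % (data[index], result))
        index += 1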
......@@ -8,7 +8,7 @@ import paddlehub as hub
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--num_epoch", type=int, default=3, help="Number of epoches for fine-tuning.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for finetuning, input should be True or False")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether use GPU for finetuning, input should be True or False")
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.")
args = parser.parse_args()
......@@ -19,7 +19,7 @@ if __name__ == '__main__':
module = hub.Module(name="senta_bilstm")
inputs, outputs, program = module.context(trainable=True)
# Step2: Download dataset and use TextClassificationReader to read dataset
# Step2: Download dataset and use LACClassifyReader to read dataset
dataset = hub.dataset.ChnSentiCorp()
reader = hub.reader.LACClassifyReader(
......@@ -27,13 +27,9 @@ if __name__ == '__main__':
sent_feature = outputs["sentence_feature"]
# Define a classification finetune task by PaddleHub's API
cls_task = hub.create_text_cls_task(
feature=sent_feature, num_classes=dataset.num_labels)
# Set up feed list for data feeder
# Must feed all the tensors that senta's module needs
feed_list = [inputs["words"].name, cls_task.variable('label').name]
feed_list = [inputs["words"].name]
strategy = hub.finetune.strategy.AdamWeightDecayStrategy(
learning_rate=1e-4, weight_decay=0.01, warmup_proportion=0.05)
......@@ -45,7 +41,14 @@ if __name__ == '__main__':
checkpoint_dir=args.checkpoint_dir,
strategy=strategy)
# Define a classification finetune task by PaddleHub's API
cls_task = hub.TextClassifierTask(
data_reader=reader,
feature=sent_feature,
feed_list=feed_list,
num_classes=dataset.num_labels,
config=config)
# Finetune and evaluate by PaddleHub's API
# will finish training, evaluation, and testing, and save the model automatically
hub.finetune_and_eval(
task=cls_task, data_reader=reader, feed_list=feed_list, config=config)
cls_task.finetune_and_eval()
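For comparison, the migrated fine-tuning script boils down to roughly the following sketch. Again this only condenses code shown in the diff above; the epoch count and batch size echo the argparse defaults, and the GPU flag and checkpoint directory are placeholders for the command-line arguments.

# Minimal sketch of the updated fine-tuning flow, condensed from the diff above.
import paddlehub as hub

module = hub.Module(name="senta_bilstm")
inputs, outputs, program = module.context(trainable=True)

dataset = hub.dataset.ChnSentiCorp()
reader = hub.reader.LACClassifyReader(
    dataset=dataset, vocab_path=module.get_vocab_path())

strategy = hub.finetune.strategy.AdamWeightDecayStrategy(
    learning_rate=1e-4, weight_decay=0.01, warmup_proportion=0.05)

config = hub.RunConfig(
    use_cuda=False,                         # placeholder for --use_gpu
    num_epoch=3,
    batch_size=32,
    checkpoint_dir="./ckpt_chnsenticorp",   # placeholder for --checkpoint_dir
    strategy=strategy)

cls_task = hub.TextClassifierTask(
    data_reader=reader,
    feature=outputs["sentence_feature"],
    feed_list=[inputs["words"].name],
    num_classes=dataset.num_labels,
    config=config)

# Runs training, evaluation, and testing, and saves checkpoints automatically
cls_task.finetune_and_eval()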