From c780dd78d0cfceffbc713528934a62a44a16f7a3 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Fri, 19 Apr 2019 15:53:53 +0800 Subject: [PATCH] add simple demo --- demo/text-classification/run_classifier.sh | 2 +- demo/text-classification/simple_demo.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/demo/text-classification/run_classifier.sh b/demo/text-classification/run_classifier.sh index 0b53a268..d2d0046b 100644 --- a/demo/text-classification/run_classifier.sh +++ b/demo/text-classification/run_classifier.sh @@ -1,4 +1,4 @@ -export CUDA_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=3 # User can select chnsenticorp, nlpcc_dbqa, lcqmc for different task DATASET="chnsenticorp" diff --git a/demo/text-classification/simple_demo.py b/demo/text-classification/simple_demo.py index bb017ca4..c5798f36 100644 --- a/demo/text-classification/simple_demo.py +++ b/demo/text-classification/simple_demo.py @@ -7,7 +7,8 @@ reader = hub.reader.ClassifyReader(hub.dataset.ChnSentiCorp(), module.get_vocab_path()) task = hub.create_text_cls_task(feature=outputs["pooled_output"], num_classes=2) strategy = hub.AdamWeightDecayStrategy(learning_rate=5e-5) -config = hub.RunConfig(use_cuda=True, num_epoch=3, strategy=strategy) +config = hub.RunConfig( + use_cuda=True, num_epoch=3, batch_size=32, strategy=strategy) feed_list = [ inputs["input_ids"].name, inputs["position_ids"].name, inputs["segment_ids"].name, inputs["input_mask"].name, -- GitLab