提交 bac07d15 编写于 作者: Z Zeyu Chen

update Senta

上级 82e1494a
......@@ -26,7 +26,7 @@ if __name__ == '__main__':
# Sentence classification dataset reader
dataset = hub.dataset.ChnSentiCorp()
reader = hub.reader.TextClassificationReader(
reader = hub.reader.LACClassifyReader(
dataset=dataset, vocab_path=module.get_vocab_path())
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
......
export CUDA_VISIBLE_DEVICES=2
export CUDA_VISIBLE_DEVICES=0
# User can select chnsenticorp, nlpcc_dbqa, lcqmc for different task
DATASET="chnsenticorp"
......@@ -6,6 +6,6 @@ CKPT_DIR="./ckpt_${DATASET}"
python -u text_classifier.py \
--batch_size=24 \
--use_gpu=True \
--use_gpu=False \
--checkpoint_dir=${CKPT_DIR} \
--num_epoch=10
export CUDA_VISIBLE_DEVICES=0
CKPT_DIR="./ckpt_chnsenticorp/best_model"
python -u predict.py --checkpoint_dir $CKPT_DIR --use_gpu True
python -u predict.py --checkpoint_dir $CKPT_DIR --use_gpu False
......@@ -21,7 +21,7 @@ if __name__ == '__main__':
# Step2: Download dataset and use TextClassificationReader to read dataset
dataset = hub.dataset.ChnSentiCorp()
reader = hub.reader.LACTokenizeReader(
reader = hub.reader.LACClassifyReader(
dataset=dataset, vocab_path=module.get_vocab_path())
sent_feature = outputs["sequence_output"]
......
......@@ -14,5 +14,5 @@
from .nlp_reader import ClassifyReader
from .nlp_reader import SequenceLabelReader
from .nlp_reader import LACTokenizeReader
from .nlp_reader import LACClassifyReader
from .cv_reader import ImageClassificationReader
......@@ -382,7 +382,7 @@ class ExtractEmbeddingReader(BaseReader):
return return_list
class LACTokenizeReader(object):
class LACClassifyReader(object):
def __init__(self, dataset, vocab_path):
self.dataset = dataset
self.lac = hub.Module(name="lac")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册