From d0689412865f405f201a46ff971e32af9568d7df Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Thu, 4 Apr 2019 11:42:39 +0800 Subject: [PATCH] fix label map issues for seq labling task --- demo/ernie-classification/finetune_with_hub.py | 2 +- .../finetune_with_hub.py | 0 .../run_fintune_with_hub.sh | 2 +- paddlehub/reader/task_reader.py | 6 +++++- 4 files changed, 7 insertions(+), 3 deletions(-) rename demo/{ernie-seq-label => ernie-seq-labeling}/finetune_with_hub.py (100%) rename demo/{ernie-seq-label => ernie-seq-labeling}/run_fintune_with_hub.sh (90%) diff --git a/demo/ernie-classification/finetune_with_hub.py b/demo/ernie-classification/finetune_with_hub.py index 8d6737ae..f167a638 100644 --- a/demo/ernie-classification/finetune_with_hub.py +++ b/demo/ernie-classification/finetune_with_hub.py @@ -52,7 +52,7 @@ if __name__ == '__main__': strategy=strategy) # loading Paddlehub BERT - module = hub.Module(module_dir=args.hub_module_dir) + module = hub.Module(name="ernie") reader = hub.reader.ClassifyReader( dataset=hub.dataset.ChnSentiCorp(), # download chnsenticorp dataset diff --git a/demo/ernie-seq-label/finetune_with_hub.py b/demo/ernie-seq-labeling/finetune_with_hub.py similarity index 100% rename from demo/ernie-seq-label/finetune_with_hub.py rename to demo/ernie-seq-labeling/finetune_with_hub.py diff --git a/demo/ernie-seq-label/run_fintune_with_hub.sh b/demo/ernie-seq-labeling/run_fintune_with_hub.sh similarity index 90% rename from demo/ernie-seq-label/run_fintune_with_hub.sh rename to demo/ernie-seq-labeling/run_fintune_with_hub.sh index 50d5ec4e..bf4a10f2 100644 --- a/demo/ernie-seq-label/run_fintune_with_hub.sh +++ b/demo/ernie-seq-labeling/run_fintune_with_hub.sh @@ -1,4 +1,4 @@ -export CUDA_VISIBLE_DEVICES=3 +export CUDA_VISIBLE_DEVICES=5 CKPT_DIR="./ckpt" diff --git a/paddlehub/reader/task_reader.py b/paddlehub/reader/task_reader.py index 5f501cd7..abed54c0 100644 --- a/paddlehub/reader/task_reader.py +++ b/paddlehub/reader/task_reader.py @@ -42,7 +42,11 @@ class BaseReader(object): np.random.seed(random_seed) - self.label_map = self.dataset.get_label_map() + # generate label map + self.label_map = {} + for index, label in enumerate(self.dataset.get_labels()): + self.label_map[label] = index + print("Dataset label map = {}".format(self.label_map)) self.current_example = 0 self.current_epoch = 0 -- GitLab