提交 5825de13 编写于 作者: W wuzewu

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleHub into develop

...@@ -52,7 +52,7 @@ if __name__ == '__main__': ...@@ -52,7 +52,7 @@ if __name__ == '__main__':
strategy=strategy) strategy=strategy)
# loading Paddlehub BERT # loading Paddlehub BERT
module = hub.Module(module_dir=args.hub_module_dir) module = hub.Module(name="ernie")
reader = hub.reader.ClassifyReader( reader = hub.reader.ClassifyReader(
dataset=hub.dataset.ChnSentiCorp(), # download chnsenticorp dataset dataset=hub.dataset.ChnSentiCorp(), # download chnsenticorp dataset
......
export CUDA_VISIBLE_DEVICES=3 export CUDA_VISIBLE_DEVICES=5
CKPT_DIR="./ckpt" CKPT_DIR="./ckpt"
......
...@@ -42,7 +42,11 @@ class BaseReader(object): ...@@ -42,7 +42,11 @@ class BaseReader(object):
np.random.seed(random_seed) np.random.seed(random_seed)
self.label_map = self.dataset.get_label_map() # generate label map
self.label_map = {}
for index, label in enumerate(self.dataset.get_labels()):
self.label_map[label] = index
print("Dataset label map = {}".format(self.label_map))
self.current_example = 0 self.current_example = 0
self.current_epoch = 0 self.current_epoch = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册