diff --git a/python/paddle/v2/dataset/conll05.py b/python/paddle/v2/dataset/conll05.py index 23f5a24a1cea7f665fb65e802e1a7811df78208d..7113202a12df9fb85272c54a09206b026b76784a 100644 --- a/python/paddle/v2/dataset/conll05.py +++ b/python/paddle/v2/dataset/conll05.py @@ -41,6 +41,28 @@ EMB_MD5 = 'bf436eb0faa1f6f9103017f8be57cdb7' UNK_IDX = 0 +def load_label_dict(filename): + d = dict() + tag_dict = set() + with open(filename, 'r') as f: + for i, line in enumerate(f): + line = line.strip() + if line.startswith("B-"): + tag_dict.add(line[2:]) + elif line.startswith("I-"): + tag_dict.add(line[2:]) + else: + continue + index = 0 + for tag in tag_dict: + d["B-" + tag] = index + index += 1 + d["I-" + tag] = index + index += 1 + d["O"] = index + return d + + def load_dict(filename): d = dict() with open(filename, 'r') as f: @@ -188,7 +210,7 @@ def get_dict(): verb_dict = load_dict( paddle.v2.dataset.common.download(VERBDICT_URL, 'conll05st', VERBDICT_MD5)) - label_dict = load_dict( + label_dict = load_label_dict( paddle.v2.dataset.common.download(TRGDICT_URL, 'conll05st', TRGDICT_MD5)) return word_dict, verb_dict, label_dict