#!/usr/bin/env python
# -*- coding: utf-8 -*-
from utils import UNK, ModelType, TaskType, load_dic, sent2ids, logger


class Dataset(object):
    '''
    Line-oriented reader over a train file and a test file.

    Every line read from either file is handed to `self.record_reader`,
    which is chosen once in `__init__` from `model_type`.
    '''

    def __init__(self,
                 train_path,
                 test_path,
                 source_dic_path,
                 target_dic_path,
                 model_type=ModelType.RANK):
        '''
        Store the paths, load both dictionaries and pick the record reader.

        train_path / test_path: paths opened by `train()` and `test()`.
        source_dic_path / target_dic_path: paths passed to `load_dic`; the
            results are kept as `self.source_dic` and `self.target_dic`.
        model_type: `ModelType.CLASSIFICATION` selects
            `_read_classification_record`; any other value selects
            `_read_rank_record`. Defaults to `ModelType.RANK`.
        '''
        self.train_path = train_path
        self.test_path = test_path
        self.source_dic_path = source_dic_path
        self.target_dic_path = target_dic_path
        self.model_type = model_type

        self.source_dic = load_dic(self.source_dic_path)
        self.target_dic = load_dic(self.target_dic_path)

        # Resolve the per-line parser once so the read loops stay branch-free.
        # NOTE(review): `_read_rank_record` is not visible in this view of the
        # file -- presumably defined further down in the class; confirm.
        self.record_reader = self._read_classification_record \
            if self.model_type == ModelType.CLASSIFICATION \
            else self._read_rank_record

    def _iter_records(self, path):
        '''
        Yield `self.record_reader(line)` for each line of the file at `path`.

        The file is opened lazily (on first iteration) and closed when the
        generator finishes or is closed.
        '''
        with open(path) as f:
            for line in f:
                yield self.record_reader(line)

    def train(self):
        '''
        Generator over the records of `self.train_path`.

        The log message is emitted when iteration starts, not when the
        generator object is created.
        '''
        logger.info("[reader] load trainset from %s" % self.train_path)
        for record in self._iter_records(self.train_path):
            yield record

    def test(self):
        '''
        Generator over the records of `self.test_path`.

        The log message is emitted when iteration starts, not when the
        generator object is created.
        '''
        logger.info("[reader] load testset from %s" % self.test_path)
        for record in self._iter_records(self.test_path):
            yield record

    # NOTE(review): this view of the file ends inside the definition below
    # (its docstring is opened but never closed), so the fragment is kept
    # verbatim as a comment instead of being guessed at. Restore the real
    # method body from the complete file.
    # def _read_classification_record(self, line): ''' data format: [TAB] [TAB]