#!/usr/bin/env python # -*- coding: utf-8 -*- from utils import UNK, ModelType, TaskType, load_dic, sent2ids, logger, ModelType class Dataset(object): def __init__(self, train_path, test_path, source_dic_path, target_dic_path, model_type): self.train_path = train_path self.test_path = test_path self.source_dic_path = source_dic_path self.target_dic_path = target_dic_path self.model_type = ModelType(model_type) self.source_dic = load_dic(self.source_dic_path) self.target_dic = load_dic(self.target_dic_path) _record_reader = { ModelType.CLASSIFICATION_MODE: self._read_classification_record, ModelType.REGRESSION_MODE: self._read_regression_record, ModelType.RANK_MODE: self._read_rank_record, } assert isinstance(model_type, ModelType) self.record_reader = _record_reader[model_type.mode] def train(self): ''' Load trainset. ''' logger.info("[reader] load trainset from %s" % self.train_path) with open(self.train_path) as f: for line_id, line in enumerate(f): yield self.record_reader(line) def test(self): ''' Load testset. ''' logger.info("[reader] load testset from %s" % self.test_path) with open(self.test_path) as f: for line_id, line in enumerate(f): yield self.record_reader(line) def _read_classification_record(self, line): ''' data format: [TAB] [TAB]