提交 5c399b8e 编写于 作者: K kinghuin 提交者: wuzewu

fix inews bug

上级 62a6b95c
...@@ -49,15 +49,15 @@ class INews(HubDataset): ...@@ -49,15 +49,15 @@ class INews(HubDataset):
def _load_train_examples(self): def _load_train_examples(self):
self.train_file = os.path.join(self.dataset_dir, "train.txt") self.train_file = os.path.join(self.dataset_dir, "train.txt")
self.train_examples = self._read_file(self.train_file) self.train_examples = self._read_file(self.train_file, is_training=True)
def _load_dev_examples(self): def _load_dev_examples(self):
self.dev_file = os.path.join(self.dataset_dir, "dev.txt") self.dev_file = os.path.join(self.dataset_dir, "dev.txt")
self.dev_examples = self._read_file(self.dev_file) self.dev_examples = self._read_file(self.dev_file, is_training=False)
def _load_test_examples(self): def _load_test_examples(self):
self.test_file = os.path.join(self.dataset_dir, "test.txt") self.test_file = os.path.join(self.dataset_dir, "test.txt")
self.test_examples = self._read_file(self.test_file) self.test_examples = self._read_file(self.test_file, is_training=False)
def get_train_examples(self): def get_train_examples(self):
return self.train_examples return self.train_examples
...@@ -78,12 +78,12 @@ class INews(HubDataset): ...@@ -78,12 +78,12 @@ class INews(HubDataset):
""" """
return len(self.get_labels()) return len(self.get_labels())
def _read_file(self, input_file): def _read_file(self, input_file, is_training):
"""Reads a tab separated value file.""" """Reads a tab separated value file."""
with io.open(input_file, "r", encoding="UTF-8") as file: with io.open(input_file, "r", encoding="UTF-8") as file:
examples = [] examples = []
for (i, line) in enumerate(file): for (i, line) in enumerate(file):
if i == 0: if i == 0 and is_training:
continue continue
data = line.strip().split("_!_") data = line.strip().split("_!_")
example = InputExample( example = InputExample(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册