提交 e4263533 编写于 作者: K kinghuin 提交者: wuzewu

update demo (#282)

* optimize demo and cv reader
上级 82c4f6e9
......@@ -168,12 +168,4 @@ if __name__ == '__main__':
"19天硬盘就罢工了~~~算上运来的一周都没用上15天~~~可就是不能换了~~~唉~~~~你说这算什么事呀~~~"
]
index = 0
run_states = elmo_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
for batch_result in results:
# get predict index
batch_result = np.argmax(batch_result, axis=2)[0]
for result in batch_result:
print("%s\tpredict=%s" % (data[index], result))
index += 1
print(elmo_task.predict(data=data, return_result=True))
......@@ -73,19 +73,7 @@ def predict(args):
config=config)
data = ["./test/test_img_daisy.jpg", "./test/test_img_roses.jpg"]
label_map = dataset.label_dict()
index = 0
# get classification result
run_states = task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
for batch_result in results:
# get predict index
batch_result = np.argmax(batch_result, axis=2)[0]
for result in batch_result:
index += 1
result = label_map[result]
print("input %i is %s, and the predict result is %s" %
(index, data[index - 1], result))
print(task.predict(data=data, return_result=True))
if __name__ == "__main__":
......
......@@ -96,15 +96,4 @@ if __name__ == '__main__':
]
index = 0
run_states = multi_label_cls_task.predict(data=data)
all_result = []
for batch_state in run_states:
batch_result = batch_state.run_results
for sample_id in range(len(batch_result[0])):
sample_result = []
for category_id in range(dataset.num_labels):
sample_category_prob = batch_result[category_id][sample_id]
sample_result.append(np.argmax(sample_category_prob))
all_result.append(sample_result)
print(all_result)
print(multi_label_cls_task.predict(data=data, return_result=True))
......@@ -90,15 +90,4 @@ if __name__ == '__main__':
["北京奥运博物馆的场景效果负责人是谁", "于海勃,美国加利福尼亚大学教授 场景效果负责人 总设计师"],
["北京奥运博物馆的场景效果负责人是谁?", "洪麦恩,清华大学美术学院教授 内容及主展线负责人 总设计师"]]
index = 0
run_states = cls_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
max_probs = 0
for index, batch_result in enumerate(results):
# get predict index
if max_probs <= batch_result[0][0, 1]:
max_probs = batch_result[0][0, 1]
max_flag = index
print("question:%s\tthe predict answer:%s\t" % (data[max_flag][0],
data[max_flag][1]))
print(cls_task.predict(data=data, return_result=True))
......@@ -113,4 +113,4 @@ if __name__ == '__main__':
# Data to be predicted
data = dataset.dev_examples[97:98]
reading_comprehension_task.predict(data=data)
print(reading_comprehension_task.predict(data=data, return_result=True))
......@@ -94,20 +94,6 @@ if __name__ == '__main__':
config=config)
    # Data to be predicted
data = [[d.text_a, d.text_b] for d in dataset.get_predict_examples()]
data = [[d.text_a, d.text_b] for d in dataset.get_predict_examples()[:3]]
index = 0
run_states = reg_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
if not os.path.exists("output"):
os.makedirs("output")
fout = open(os.path.join("output", "%s.tsv" % args.dataset.upper()), 'w')
fout.write("index\tprediction")
for batch_result in results:
for result in batch_result[0]:
if index < 3:
print("%s\t%s\tpredict=%.3f" % (data[index][0], data[index][1],
result[0]))
fout.write("\n%s\t%.3f" % (index, result[0]))
index += 1
fout.close()
print(reg_task.predict(data=data, return_result=True))
......@@ -59,11 +59,4 @@ if __name__ == '__main__':
data = ["这家餐厅很好吃", "这部电影真的很差劲"]
run_states = cls_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
index = 0
for batch_result in results:
batch_result = np.argmax(batch_result, axis=2)[0]
for result in batch_result:
print("%s\tpredict=%s" % (data[index], result))
index += 1
print(cls_task.predict(data=data, return_result=True))
......@@ -111,26 +111,4 @@ if __name__ == '__main__':
tmp_data.append(formatted)
data = tmp_data
run_states = seq_label_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
for num_batch, batch_results in enumerate(results):
infers = batch_results[0].reshape([-1]).astype(np.int32).tolist()
np_lens = batch_results[1]
for index, np_len in enumerate(np_lens):
labels = infers[index * args.max_seq_len:(index + 1) *
args.max_seq_len]
label_str = ""
count = 0
for label_val in labels:
label_str += inv_label_map[label_val]
count += 1
if count == np_len:
break
# Drop the label results of CLS and SEP Token
print(
"%s\tpredict=%s" %
(data[num_batch * args.batch_size + index][0], label_str[1:-1]))
print(seq_label_task.predict(data=data, return_result=True))
......@@ -163,12 +163,4 @@ if __name__ == '__main__':
    # Data to be predicted
data = [[d.text_a, d.text_b] for d in dataset.get_dev_examples()[:3]]
index = 0
run_states = cls_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
for batch_result in results:
# get predict index
batch_result = np.argmax(batch_result, axis=2)[0]
for result in batch_result:
print("%s\tpredict=%s" % (data[index][0], result))
index += 1
print(cls_task.predict(data=data, return_result=True))
......@@ -128,6 +128,7 @@ class BaseDataset(object):
def num_labels(self):
return len(self.label_list)
    # For compatibility with the usage of ImageClassificationDataset
def label_dict(self):
return {index: key for index, key in enumerate(self.label_list)}
......
......@@ -142,7 +142,8 @@ class ClassifierTask(BaseTask):
}
except:
raise Exception(
"image-classification does not support return_result now")
"ImageClassificationDataset does not support postprocessing, please use BaseCVDatast instead"
)
results = []
for batch_state in run_states:
batch_result = batch_state.run_results
......@@ -320,6 +321,7 @@ class MultiLabelClassifierTask(ClassifierTask):
def _postprocessing(self, run_states):
results = []
label_list = list(self._base_data_reader.label_map.keys())
for batch_state in run_states:
batch_result = batch_state.run_results
for sample_id in range(len(batch_result[0])):
......@@ -327,6 +329,8 @@ class MultiLabelClassifierTask(ClassifierTask):
for category_id in range(
self._base_data_reader.dataset.num_labels):
sample_category_prob = batch_result[category_id][sample_id]
sample_result.append(np.argmax(sample_category_prob))
sample_category_value = np.argmax(sample_category_prob)
sample_result.append(
{label_list[category_id]: sample_category_value})
results.append(sample_result)
return results
......@@ -623,8 +623,6 @@ class ModuleV1(Module):
"input_ids", "position_ids", "segment_ids", "input_mask",
"task_ids"
]
logger.warning("For %s, it's no necessary to feed task_ids now."
% self.name)
else:
feed_list = [
"input_ids", "position_ids", "segment_ids", "input_mask"
......
import numpy as np
from paddlehub.common.logger import logger
class BaseReader(object):
def __init__(self, dataset, random_seed=None):
......@@ -7,6 +9,18 @@ class BaseReader(object):
self.num_examples = {'train': -1, 'dev': -1, 'test': -1}
np.random.seed(random_seed)
# generate label map
self.label_map = {}
try:
for index, label in enumerate(self.dataset.get_labels()):
self.label_map[label] = index
logger.info("Dataset label map = {}".format(self.label_map))
except:
            # some datasets, such as SQuAD, have label_list=None
logger.info(
"Dataset is None or it has not any labels, label map = {}".
format(self.label_map))
def get_train_examples(self):
return self.dataset.get_train_examples()
......
......@@ -63,21 +63,10 @@ class BaseNLPReader(BaseReader):
if self.use_task_id:
logger.warning(
"use_task_id has been de discarded since PaddleHub v1.4.0")
"use_task_id has been de discarded since PaddleHub v1.4.0, it's no necessary to feed task_ids now."
)
self.task_id = 0
# generate label map
self.label_map = {}
try:
for index, label in enumerate(self.dataset.get_labels()):
self.label_map[label] = index
logger.info("Dataset label map = {}".format(self.label_map))
except:
            # some datasets, such as SQuAD, have label_list=None
logger.info(
"Dataset is None or it has not any labels, label map = {}".
format(self.label_map))
self.Record_With_Label_Id = namedtuple(
'Record',
['token_ids', 'text_type_ids', 'position_ids', 'label_id'])
......@@ -375,7 +364,7 @@ class SequenceLabelReader(BaseNLPReader):
pad_idx=self.pad_id)
if phase != "predict":
batch_label_ids = [record.label_ids for record in batch_records]
batch_label_ids = [record.label_id for record in batch_records]
padded_label_ids = pad_batch_data(
batch_label_ids,
max_seq_len=self.max_seq_len,
......@@ -520,7 +509,7 @@ class MultiLabelClassifyReader(BaseNLPReader):
pad_idx=self.pad_id)
if phase != "predict":
batch_labels_ids = [record.label_ids for record in batch_records]
batch_labels_ids = [record.label_id for record in batch_records]
num_label = len(self.dataset.get_labels())
batch_labels = np.array(batch_labels_ids).astype("int64").reshape(
[-1, num_label])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册