未验证 提交 4386b841 编写于 作者: G gaotingquan

fix: support bs>1

上级 6d2de979
...@@ -29,10 +29,10 @@ Modules: ...@@ -29,10 +29,10 @@ Modules:
inference_model_dir: "./MobileNetV2_infer" inference_model_dir: "./MobileNetV2_infer"
to_model_names: to_model_names:
image: inputs image: inputs
from_model_names: from_model_indexes:
logits: 0 logits: 0
- name: TopK - name: TopK
type: postprocessor type: postprocessor
k: 10 k: 10
class_id_map_file: "../ppcls/utils/imagenet1k_label_list.txt" class_id_map_file: "../../../ppcls/utils/imagenet1k_label_list.txt"
save_dir: None save_dir: None
\ No newline at end of file
...@@ -25,6 +25,8 @@ Modules: ...@@ -25,6 +25,8 @@ Modules:
- name: PaddlePredictor - name: PaddlePredictor
type: predictor type: predictor
inference_model_dir: ./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer/ inference_model_dir: ./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer/
from_model_indexes:
boxes: 0
- name: DetPostPro - name: DetPostPro
type: postprocessor type: postprocessor
threshold: 0.2 threshold: 0.2
......
...@@ -28,7 +28,7 @@ Modules: ...@@ -28,7 +28,7 @@ Modules:
inference_model_dir: models/product_ResNet50_vd_aliproduct_v1.0_infer inference_model_dir: models/product_ResNet50_vd_aliproduct_v1.0_infer
to_model_names: to_model_names:
image: x image: x
from_model_names: from_model_indexes:
features: 0 features: 0
- name: FeatureNormalizer - name: FeatureNormalizer
type: postprocessor type: postprocessor
\ No newline at end of file
...@@ -39,26 +39,28 @@ class TopK(BaseProcessor): ...@@ -39,26 +39,28 @@ class TopK(BaseProcessor):
return class_id_map return class_id_map
def process(self, data): def process(self, data):
# TODO(gaotingquan): only support bs==1 when 'connector' is not implemented. logits = data["pred"]["logits"]
probs = data["pred"]["logits"][0] all_results = []
index = probs.argsort(axis=0)[-self.topk:][::-1].astype( for probs in logits:
"int32") if not self.multilabel else np.where( index = probs.argsort(axis=0)[-self.topk:][::-1].astype(
probs >= 0.5)[0].astype("int32") "int32") if not self.multilabel else np.where(
clas_id_list = [] probs >= 0.5)[0].astype("int32")
score_list = [] clas_id_list = []
label_name_list = [] score_list = []
for i in index: label_name_list = []
clas_id_list.append(i.item()) for i in index:
score_list.append(probs[i].item()) clas_id_list.append(i.item())
if self.class_id_map is not None: score_list.append(probs[i].item())
label_name_list.append(self.class_id_map[i.item()]) if self.class_id_map is not None:
result = { label_name_list.append(self.class_id_map[i.item()])
"class_ids": clas_id_list, result = {
"scores": np.around( "class_ids": clas_id_list,
score_list, decimals=5).tolist(), "scores": np.around(
} score_list, decimals=5).tolist(),
if label_name_list is not None: }
result["label_names"] = label_name_list if label_name_list is not None:
result["label_names"] = label_name_list
all_results.append(result)
data["classification_res"] = result data["classification_res"] = all_results
return data return data
...@@ -12,33 +12,29 @@ class DetPostPro(BaseProcessor): ...@@ -12,33 +12,29 @@ class DetPostPro(BaseProcessor):
self.max_det_results = config["max_det_results"] self.max_det_results = config["max_det_results"]
def process(self, data): def process(self, data):
pred = data["pred"] np_boxes = data["pred"]["boxes"]
np_boxes = pred[list(pred.keys())[0]]
if reduce(lambda x, y: x * y, np_boxes.shape) >= 6: if reduce(lambda x, y: x * y, np_boxes.shape) >= 6:
keep_indexes = np_boxes[:, 1].argsort()[::-1][: keep_indexes = np_boxes[:, 1].argsort()[::-1][:
self.max_det_results] self.max_det_results]
# TODO(gaotingquan): only support bs==1
single_res = np_boxes[0] all_results = []
class_id = int(single_res[0]) for idx in keep_indexes:
score = single_res[1] single_res = np_boxes[idx]
bbox = single_res[2:] class_id = int(single_res[0])
if score > self.threshold: score = single_res[1]
bbox = single_res[2:]
if score < self.threshold:
continue
label_name = self.label_list[class_id] label_name = self.label_list[class_id]
results = { all_results.append({
"class_id": class_id, "class_id": class_id,
"score": score, "score": score,
"bbox": bbox, "bbox": bbox,
"label_name": label_name, "label_name": label_name
} })
data["detection_res"] = results data["detection_res"] = all_results
return data return data
logger.warning('[Detector] No object detected.') logger.warning('[Detector] No object detected.')
results = { data["detection_res"] = []
"class_id": None,
"score": None,
"bbox": None,
"label_name": None,
}
data["detection_res"] = results
return data return data
...@@ -55,10 +55,8 @@ class PaddlePredictor(BaseProcessor): ...@@ -55,10 +55,8 @@ class PaddlePredictor(BaseProcessor):
} }
else: else:
self.input_name_map = {} self.input_name_map = {}
if "from_model_names" in config and config["from_model_names"]:
self.output_name_map = config["from_model_names"] self.output_name_map = config["from_model_indexes"]
else:
self.output_name_map = {}
def process(self, data): def process(self, data):
input_names = self.predictor.get_input_names() input_names = self.predictor.get_input_names()
...@@ -73,15 +71,12 @@ class PaddlePredictor(BaseProcessor): ...@@ -73,15 +71,12 @@ class PaddlePredictor(BaseProcessor):
output_names = self.predictor.get_output_names() output_names = self.predictor.get_output_names()
for output_name in output_names: for output_name in output_names:
output = self.predictor.get_output_handle(output_name) output = self.predictor.get_output_handle(output_name)
model_output.append((output_name, output.copy_to_cpu())) model_output.append(output.copy_to_cpu())
if self.output_name_map: output_data = {}
output_data = {} for name in self.output_name_map:
for name in self.output_name_map: idx = self.output_name_map[name]
idx = self.output_name_map[name] output_data[name] = model_output[idx]
output_data[name] = model_output[idx][1]
else:
output_data = dict(model_output)
data["pred"] = output_data data["pred"] = output_data
return data return data
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册