提交 764455a5 编写于 作者: F FlyingQianMM

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleX into develop_qh

...@@ -118,13 +118,15 @@ public class Visualize { ...@@ -118,13 +118,15 @@ public class Visualize {
public Mat draw(SegResult result, Mat visualizeMat, ImageBlob imageBlob, int cutoutClass) { public Mat draw(SegResult result, Mat visualizeMat, ImageBlob imageBlob, int cutoutClass) {
int new_h = (int)imageBlob.getNewImageSize()[2]; int new_h = (int)imageBlob.getNewImageSize()[2];
int new_w = (int)imageBlob.getNewImageSize()[3]; int new_w = (int)imageBlob.getNewImageSize()[3];
Mat mask = new Mat(new_h, new_w, CvType.CV_8UC(1)); Mat mask = new Mat(new_h, new_w, CvType.CV_32FC(1));
float[] scoreData = new float[new_h*new_w];
for (int h = 0; h < new_h; h++) { for (int h = 0; h < new_h; h++) {
for (int w = 0; w < new_w; w++){ for (int w = 0; w < new_w; w++){
mask.put(h , w, (1-result.getMask().getScoreData()[cutoutClass + h * new_h + w]) * 255); scoreData[new_h * h + w] = (1-result.getMask().getScoreData()[cutoutClass + h * new_h + w]) * 255;
} }
} }
mask.put(0,0, scoreData);
mask.convertTo(mask,CvType.CV_8UC(1));
ListIterator<Map.Entry<String, int[]>> reverseReshapeInfo = new ArrayList<Map.Entry<String, int[]>>(imageBlob.getReshapeInfo().entrySet()).listIterator(imageBlob.getReshapeInfo().size()); ListIterator<Map.Entry<String, int[]>> reverseReshapeInfo = new ArrayList<Map.Entry<String, int[]>>(imageBlob.getReshapeInfo().entrySet()).listIterator(imageBlob.getReshapeInfo().size());
while (reverseReshapeInfo.hasPrevious()) { while (reverseReshapeInfo.hasPrevious()) {
Map.Entry<String, int[]> entry = reverseReshapeInfo.previous(); Map.Entry<String, int[]> entry = reverseReshapeInfo.previous();
...@@ -135,10 +137,7 @@ public class Visualize { ...@@ -135,10 +137,7 @@ public class Visualize {
Size sz = new Size(entry.getValue()[0], entry.getValue()[1]); Size sz = new Size(entry.getValue()[0], entry.getValue()[1]);
Imgproc.resize(mask, mask, sz,0,0,Imgproc.INTER_LINEAR); Imgproc.resize(mask, mask, sz,0,0,Imgproc.INTER_LINEAR);
} }
Log.i(TAG, "postprocess operator: " + entry.getKey());
Log.i(TAG, "shape:: " + String.valueOf(mask.width()) + ","+ String.valueOf(mask.height()));
} }
Mat dst = new Mat(); Mat dst = new Mat();
List<Mat> listMat = Arrays.asList(visualizeMat, mask); List<Mat> listMat = Arrays.asList(visualizeMat, mask);
Core.merge(listMat, dst); Core.merge(listMat, dst);
......
...@@ -39,6 +39,9 @@ def split_seg_dataset(dataset_dir, val_percent, test_percent, save_dir): ...@@ -39,6 +39,9 @@ def split_seg_dataset(dataset_dir, val_percent, test_percent, save_dir):
anno_name = replace_ext(image_file, "PNG") anno_name = replace_ext(image_file, "PNG")
if osp.exists(osp.join(dataset_dir, "Annotations", anno_name)): if osp.exists(osp.join(dataset_dir, "Annotations", anno_name)):
image_anno_list.append([image_file, anno_name]) image_anno_list.append([image_file, anno_name])
else:
logging.error("The annotation file {} doesn't exist!".format(
anno_name))
if not osp.exists(osp.join(dataset_dir, "labels.txt")): if not osp.exists(osp.join(dataset_dir, "labels.txt")):
for image_anno in image_anno_list: for image_anno in image_anno_list:
......
...@@ -47,6 +47,9 @@ def split_voc_dataset(dataset_dir, val_percent, test_percent, save_dir): ...@@ -47,6 +47,9 @@ def split_voc_dataset(dataset_dir, val_percent, test_percent, save_dir):
cname = obj.find('name').text cname = obj.find('name').text
if not cname in label_list: if not cname in label_list:
label_list.append(cname) label_list.append(cname)
else:
logging.error("The annotation file {} doesn't exist!".format(
anno_name))
random.shuffle(image_anno_list) random.shuffle(image_anno_list)
image_num = len(image_anno_list) image_num = len(image_anno_list)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册