diff --git a/contrib/SpatialEmbeddings/infer.py b/contrib/SpatialEmbeddings/infer.py index c777f2d5e4cac154dc778ec0303b57d58f91e024..9b6c4db891ef54e56fc8881c8941c11dc07e46f9 100644 --- a/contrib/SpatialEmbeddings/infer.py +++ b/contrib/SpatialEmbeddings/infer.py @@ -15,14 +15,15 @@ cfg = getattr(config, 'cfg') cluster = Cluster() + # 预测数据集类 class TestDataSet(): def __init__(self): - self.data_dir = cfg.data_dir + self.data_dir = cfg.data_dir self.data_list_file = cfg.data_list_file self.data_list = self.get_data_list() self.data_num = len(self.data_list) - + def get_data_list(self): # 获取预测图像路径列表 data_list = [] @@ -41,7 +42,7 @@ class TestDataSet(): h, w = img.shape[:2] h_new, w_new = cfg.input_size img = np.pad(img, ((0, h_new - h), (0, w_new - w), (0, 0)), 'edge') - img = img.astype(np.float32)/255.0 + img = img.astype(np.float32) / 255.0 img = img.transpose((2, 0, 1)) img = np.expand_dims(img, axis=0) return img @@ -51,30 +52,33 @@ class TestDataSet(): img_path = self.data_list[index] img = np.array(PILImage.open(img_path)) if img is None: - return img, img,img_path, None + return img, img, img_path, None img_name = img_path.split(os.sep)[-1] - name_prefix = img_name.replace('.'+img_name.split('.')[-1],'') + name_prefix = img_name.replace('.' + img_name.split('.')[-1], '') img_shape = img.shape[:2] img_process = self.preprocess(img) return img_process, name_prefix, img_shape + def get_model(main_prog, startup_prog): img_shape = [3, cfg.input_size[0], cfg.input_size[1]] with fluid.program_guard(main_prog, startup_prog): with fluid.unique_name.guard(): - input = fluid.layers.data(name='image', shape=img_shape, dtype='float32') + input = fluid.layers.data( + name='image', shape=img_shape, dtype='float32') output = SpatialEmbeddings(input) return input, output + def infer(): if not os.path.exists(cfg.vis_dir): os.makedirs(cfg.vis_dir) startup_prog = fluid.Program() test_prog = fluid.Program() - + input, output = get_model(test_prog, startup_prog) test_prog = test_prog.clone(for_test=True) @@ -82,11 +86,17 @@ def infer(): exe = fluid.Executor(place) exe.run(startup_prog) + if not os.path.exists(cfg.model_path): + raise RuntimeError('No pre-trained model found under path {}'.format( + cfg.model_path)) + # 加载预测模型 def if_exist(var): return os.path.exists(os.path.join(cfg.model_path, var.name)) - fluid.io.load_vars(exe, cfg.model_path, main_program=test_prog, predicate=if_exist) - + + fluid.io.load_vars( + exe, cfg.model_path, main_program=test_prog, predicate=if_exist) + #加载预测数据集 test_dataset = TestDataSet() data_num = test_dataset.data_num @@ -97,9 +107,10 @@ def infer(): if image is None: print(im_name, 'is None') continue - + # 预测 - outputs = exe.run(program=test_prog, feed={'image': image}, fetch_list=output) + outputs = exe.run( + program=test_prog, feed={'image': image}, fetch_list=output) instance_map, predictions = cluster.cluster(outputs[0][0], n_sigma=cfg.n_sigma, \ min_pixel=cfg.min_pixel, threshold=cfg.threshold) @@ -109,14 +120,14 @@ def infer(): output_im = PILImage.fromarray(np.asarray(instance_map, dtype=np.uint8)) palette = get_palette(len(predictions) + 1) output_im.putpalette(palette) - result_path = os.path.join(cfg.vis_dir, im_name+'.png') + result_path = os.path.join(cfg.vis_dir, im_name + '.png') output_im.save(result_path) if (idx + 1) % 100 == 0: print('%d processd' % (idx + 1)) - - print('%d processd done' % (idx + 1)) - + + print('%d processd done' % (idx + 1)) + return 0