提交 886a6fea 编写于 作者: W wuzewu

Fix infer bug

上级 5ae538c0
...@@ -15,14 +15,15 @@ cfg = getattr(config, 'cfg') ...@@ -15,14 +15,15 @@ cfg = getattr(config, 'cfg')
cluster = Cluster() cluster = Cluster()
# 预测数据集类 # 预测数据集类
class TestDataSet(): class TestDataSet():
def __init__(self): def __init__(self):
self.data_dir = cfg.data_dir self.data_dir = cfg.data_dir
self.data_list_file = cfg.data_list_file self.data_list_file = cfg.data_list_file
self.data_list = self.get_data_list() self.data_list = self.get_data_list()
self.data_num = len(self.data_list) self.data_num = len(self.data_list)
def get_data_list(self): def get_data_list(self):
# 获取预测图像路径列表 # 获取预测图像路径列表
data_list = [] data_list = []
...@@ -41,7 +42,7 @@ class TestDataSet(): ...@@ -41,7 +42,7 @@ class TestDataSet():
h, w = img.shape[:2] h, w = img.shape[:2]
h_new, w_new = cfg.input_size h_new, w_new = cfg.input_size
img = np.pad(img, ((0, h_new - h), (0, w_new - w), (0, 0)), 'edge') img = np.pad(img, ((0, h_new - h), (0, w_new - w), (0, 0)), 'edge')
img = img.astype(np.float32)/255.0 img = img.astype(np.float32) / 255.0
img = img.transpose((2, 0, 1)) img = img.transpose((2, 0, 1))
img = np.expand_dims(img, axis=0) img = np.expand_dims(img, axis=0)
return img return img
...@@ -51,30 +52,33 @@ class TestDataSet(): ...@@ -51,30 +52,33 @@ class TestDataSet():
img_path = self.data_list[index] img_path = self.data_list[index]
img = np.array(PILImage.open(img_path)) img = np.array(PILImage.open(img_path))
if img is None: if img is None:
return img, img,img_path, None return img, img, img_path, None
img_name = img_path.split(os.sep)[-1] img_name = img_path.split(os.sep)[-1]
name_prefix = img_name.replace('.'+img_name.split('.')[-1],'') name_prefix = img_name.replace('.' + img_name.split('.')[-1], '')
img_shape = img.shape[:2] img_shape = img.shape[:2]
img_process = self.preprocess(img) img_process = self.preprocess(img)
return img_process, name_prefix, img_shape return img_process, name_prefix, img_shape
def get_model(main_prog, startup_prog): def get_model(main_prog, startup_prog):
img_shape = [3, cfg.input_size[0], cfg.input_size[1]] img_shape = [3, cfg.input_size[0], cfg.input_size[1]]
with fluid.program_guard(main_prog, startup_prog): with fluid.program_guard(main_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
input = fluid.layers.data(name='image', shape=img_shape, dtype='float32') input = fluid.layers.data(
name='image', shape=img_shape, dtype='float32')
output = SpatialEmbeddings(input) output = SpatialEmbeddings(input)
return input, output return input, output
def infer(): def infer():
if not os.path.exists(cfg.vis_dir): if not os.path.exists(cfg.vis_dir):
os.makedirs(cfg.vis_dir) os.makedirs(cfg.vis_dir)
startup_prog = fluid.Program() startup_prog = fluid.Program()
test_prog = fluid.Program() test_prog = fluid.Program()
input, output = get_model(test_prog, startup_prog) input, output = get_model(test_prog, startup_prog)
test_prog = test_prog.clone(for_test=True) test_prog = test_prog.clone(for_test=True)
...@@ -82,11 +86,17 @@ def infer(): ...@@ -82,11 +86,17 @@ def infer():
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup_prog) exe.run(startup_prog)
if not os.path.exists(cfg.model_path):
raise RuntimeError('No pre-trained model found under path {}'.format(
cfg.model_path))
# 加载预测模型 # 加载预测模型
def if_exist(var): def if_exist(var):
return os.path.exists(os.path.join(cfg.model_path, var.name)) return os.path.exists(os.path.join(cfg.model_path, var.name))
fluid.io.load_vars(exe, cfg.model_path, main_program=test_prog, predicate=if_exist)
fluid.io.load_vars(
exe, cfg.model_path, main_program=test_prog, predicate=if_exist)
#加载预测数据集 #加载预测数据集
test_dataset = TestDataSet() test_dataset = TestDataSet()
data_num = test_dataset.data_num data_num = test_dataset.data_num
...@@ -97,9 +107,10 @@ def infer(): ...@@ -97,9 +107,10 @@ def infer():
if image is None: if image is None:
print(im_name, 'is None') print(im_name, 'is None')
continue continue
# 预测 # 预测
outputs = exe.run(program=test_prog, feed={'image': image}, fetch_list=output) outputs = exe.run(
program=test_prog, feed={'image': image}, fetch_list=output)
instance_map, predictions = cluster.cluster(outputs[0][0], n_sigma=cfg.n_sigma, \ instance_map, predictions = cluster.cluster(outputs[0][0], n_sigma=cfg.n_sigma, \
min_pixel=cfg.min_pixel, threshold=cfg.threshold) min_pixel=cfg.min_pixel, threshold=cfg.threshold)
...@@ -109,14 +120,14 @@ def infer(): ...@@ -109,14 +120,14 @@ def infer():
output_im = PILImage.fromarray(np.asarray(instance_map, dtype=np.uint8)) output_im = PILImage.fromarray(np.asarray(instance_map, dtype=np.uint8))
palette = get_palette(len(predictions) + 1) palette = get_palette(len(predictions) + 1)
output_im.putpalette(palette) output_im.putpalette(palette)
result_path = os.path.join(cfg.vis_dir, im_name+'.png') result_path = os.path.join(cfg.vis_dir, im_name + '.png')
output_im.save(result_path) output_im.save(result_path)
if (idx + 1) % 100 == 0: if (idx + 1) % 100 == 0:
print('%d processd' % (idx + 1)) print('%d processd' % (idx + 1))
print('%d processd done' % (idx + 1)) print('%d processd done' % (idx + 1))
return 0 return 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册