提交 886a6fea 编写于 作者: W wuzewu

Fix infer bug

上级 5ae538c0
......@@ -15,6 +15,7 @@ cfg = getattr(config, 'cfg')
cluster = Cluster()
# 预测数据集类
class TestDataSet():
def __init__(self):
......@@ -41,7 +42,7 @@ class TestDataSet():
h, w = img.shape[:2]
h_new, w_new = cfg.input_size
img = np.pad(img, ((0, h_new - h), (0, w_new - w), (0, 0)), 'edge')
img = img.astype(np.float32)/255.0
img = img.astype(np.float32) / 255.0
img = img.transpose((2, 0, 1))
img = np.expand_dims(img, axis=0)
return img
......@@ -51,23 +52,26 @@ class TestDataSet():
img_path = self.data_list[index]
img = np.array(PILImage.open(img_path))
if img is None:
return img, img,img_path, None
return img, img, img_path, None
img_name = img_path.split(os.sep)[-1]
name_prefix = img_name.replace('.'+img_name.split('.')[-1],'')
name_prefix = img_name.replace('.' + img_name.split('.')[-1], '')
img_shape = img.shape[:2]
img_process = self.preprocess(img)
return img_process, name_prefix, img_shape
def get_model(main_prog, startup_prog):
img_shape = [3, cfg.input_size[0], cfg.input_size[1]]
with fluid.program_guard(main_prog, startup_prog):
with fluid.unique_name.guard():
input = fluid.layers.data(name='image', shape=img_shape, dtype='float32')
input = fluid.layers.data(
name='image', shape=img_shape, dtype='float32')
output = SpatialEmbeddings(input)
return input, output
def infer():
if not os.path.exists(cfg.vis_dir):
os.makedirs(cfg.vis_dir)
......@@ -82,10 +86,16 @@ def infer():
exe = fluid.Executor(place)
exe.run(startup_prog)
if not os.path.exists(cfg.model_path):
raise RuntimeError('No pre-trained model found under path {}'.format(
cfg.model_path))
# 加载预测模型
def if_exist(var):
return os.path.exists(os.path.join(cfg.model_path, var.name))
fluid.io.load_vars(exe, cfg.model_path, main_program=test_prog, predicate=if_exist)
fluid.io.load_vars(
exe, cfg.model_path, main_program=test_prog, predicate=if_exist)
#加载预测数据集
test_dataset = TestDataSet()
......@@ -99,7 +109,8 @@ def infer():
continue
# 预测
outputs = exe.run(program=test_prog, feed={'image': image}, fetch_list=output)
outputs = exe.run(
program=test_prog, feed={'image': image}, fetch_list=output)
instance_map, predictions = cluster.cluster(outputs[0][0], n_sigma=cfg.n_sigma, \
min_pixel=cfg.min_pixel, threshold=cfg.threshold)
......@@ -109,7 +120,7 @@ def infer():
output_im = PILImage.fromarray(np.asarray(instance_map, dtype=np.uint8))
palette = get_palette(len(predictions) + 1)
output_im.putpalette(palette)
result_path = os.path.join(cfg.vis_dir, im_name+'.png')
result_path = os.path.join(cfg.vis_dir, im_name + '.png')
output_im.save(result_path)
if (idx + 1) % 100 == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册