提交 4d67dba0 编写于 作者: 工药叉's avatar 工药叉

fix some bug

上级 ce96b3a9
......@@ -78,7 +78,7 @@ mode: 指定模型结构。如果为“base”,使用论文baseline模型结
### 预测
执行以下命令得到模型的预测结果。
```
python infer.py --checkpoint_path="./chkpnt/" --use_gpu=True --image_path="data/val_dataset/set5/baby_GT.bmp"
python infer.py --checkpoint_path="./chkpnt/" --image_path="data/val_dataset/set5/baby_GT.bmp"
```
需要通过选项`--checkpoint_path`指定模型文件。并使用`--image_path`指定要进行预测的图片。
......
......@@ -18,7 +18,7 @@ add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('checkpoint_path', str, 'model', "Checkpoint save path.")
add_arg('image_path', str, 'data/val_dataset/set5/baby_GT.bmp', "Img data path.")
add_arg('show_img', bool, True, "show img or not")
add_arg('show_img', bool, False, "show img or not")
add_arg('only_reconstruct', bool, False, "If True, input image is seemed as subsampled image")
add_arg('scale_factor', int, 3, "scale factor")
......@@ -29,8 +29,8 @@ def reconstruct_img(args):
img_test = cv2.imread(args.image_path)
yuv_test = cv2.cvtColor(img_test, cv2.COLOR_BGR2YCrCb)
img_h, img_w, img_c = img_test.shape
cv2.imshow('raw image', img_test)
if args.show_img:
cv2.imshow('raw image', img_test)
if args.only_reconstruct == False:
# blur image and cubic interpolation
......@@ -55,16 +55,17 @@ def reconstruct_img(args):
result_img[result_img >255] = 255
gap_y = int((img_y.shape[0]-result_img.shape[2])/2)
gap_x = int((img_y.shape[1]-result_img.shape[3])/2)
cv2.imshow('input_channel y', img_y)
if args.show_img:
cv2.imshow('input_channel y', img_y)
cv2.imwrite(os.path.join(os.path.split(args.image_path)[0],'beforeSR_'+os.path.split(args.image_path)[1]), img_y)
img_y[gap_y: gap_y + result_img.shape[2],
gap_x: gap_x + result_img.shape[3]]=result_img
if args.show_img:
cv2.imshow('output_channel y', img_y)
cv2.waitKey(0)
cv2.imshow('output_channel y', img_y)
cv2.waitKey(0)
cv2.destroyAllWindows()
cv2.destroyAllWindows()
cv2.imwrite(os.path.join(os.path.split(args.image_path)[0],'afterSR_'+os.path.split(args.image_path)[1]), img_y)
return img_y
if __name__ == "__main__":
......
......@@ -37,17 +37,17 @@ N2=32
def net(X, Y, model_struct):
# construct net
conv1 = fluid.layers.nn.conv2d(X, model_struct.n1, model_struct.f1, act='relu', name='conv1' ,
conv1 = fluid.layers.conv2d(X, model_struct.n1, model_struct.f1, act='relu', name='conv1' ,
param_attr= fluid.ParamAttr(initializer=fluid.initializer.NormalInitializer(scale=0.001),
name='conv1_w'),
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(value=0.),
name='conv1_b'))
conv2 = fluid.layers.nn.conv2d(conv1, model_struct.n2, model_struct.f2, act='relu', name='conv2' ,
conv2 = fluid.layers.conv2d(conv1, model_struct.n2, model_struct.f2, act='relu', name='conv2' ,
param_attr= fluid.ParamAttr(initializer=fluid.initializer.NormalInitializer(scale=0.001),
name='conv2_w'),
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(value=0.),
name='conv2_b'))
pred = fluid.layers.nn.conv2d(conv2, 1, model_struct.f3, name='pred',
pred = fluid.layers.conv2d(conv2, 1, model_struct.f3, name='pred',
param_attr= fluid.ParamAttr(initializer=fluid.initializer.NormalInitializer(scale=0.001),
name='pred_w'),
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(value=0.),
......@@ -94,7 +94,7 @@ def train(args):
fluid.framework.default_main_program(),
feed=feeder.feed(data),
fetch_list=[y_loss])
if batch_id == 0 or backprops_cnt % 100 ==0:
if batch_id == 0:
fluid.io.save_inference_model(args.checkpoint_path, ['image'], [y_predict], exe)
val_loss, val_psnr = validation()
print("%i\tEpoch: %d \tCur Cost : %f\t Val Cost: %f\t PSNR :%f" % (backprops_cnt, epoch, np.array(loss[0])[0], val_loss, val_psnr))
......
......@@ -50,8 +50,8 @@ def read_data(data_path, batch_size, ext, scale_factor, bia_size):
img_blur = cv2.GaussianBlur(img_patch, (5, 5), 0)
img_sumsample = cv2.resize(img_blur, (int(33/scale_factor), int(33/scale_factor)))
img_input = cv2.resize(img_blur, (33, 33), interpolation=cv2.INTER_CUBIC)
img_inputs.append(img_input)
img_gths.append(img_gth)
img_inputs.append([img_input])
img_gths.append([img_gth])
count_bt += 1
if count_bt % batch_size == 0:
yield [[np.array(img_inputs), np.array(img_gths)]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册