提交 4d67dba0 编写于 作者: 工药叉's avatar 工药叉

fix some bug

上级 ce96b3a9
...@@ -78,7 +78,7 @@ mode: 指定模型结构。如果为“base”,使用论文baseline模型结 ...@@ -78,7 +78,7 @@ mode: 指定模型结构。如果为“base”,使用论文baseline模型结
### 预测 ### 预测
执行以下命令得到模型的预测结果。 执行以下命令得到模型的预测结果。
``` ```
python infer.py --checkpoint_path="./chkpnt/" --use_gpu=True --image_path="data/val_dataset/set5/baby_GT.bmp" python infer.py --checkpoint_path="./chkpnt/" --image_path="data/val_dataset/set5/baby_GT.bmp"
``` ```
需要通过选项`--checkpoint_path`指定模型文件。并使用`--image_path`指定要进行预测的图片。 需要通过选项`--checkpoint_path`指定模型文件。并使用`--image_path`指定要进行预测的图片。
......
...@@ -18,7 +18,7 @@ add_arg = functools.partial(add_arguments, argparser=parser) ...@@ -18,7 +18,7 @@ add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('checkpoint_path', str, 'model', "Checkpoint save path.") add_arg('checkpoint_path', str, 'model', "Checkpoint save path.")
add_arg('image_path', str, 'data/val_dataset/set5/baby_GT.bmp', "Img data path.") add_arg('image_path', str, 'data/val_dataset/set5/baby_GT.bmp', "Img data path.")
add_arg('show_img', bool, True, "show img or not") add_arg('show_img', bool, False, "show img or not")
add_arg('only_reconstruct', bool, False, "If True, input image is seemed as subsampled image") add_arg('only_reconstruct', bool, False, "If True, input image is seemed as subsampled image")
add_arg('scale_factor', int, 3, "scale factor") add_arg('scale_factor', int, 3, "scale factor")
...@@ -29,8 +29,8 @@ def reconstruct_img(args): ...@@ -29,8 +29,8 @@ def reconstruct_img(args):
img_test = cv2.imread(args.image_path) img_test = cv2.imread(args.image_path)
yuv_test = cv2.cvtColor(img_test, cv2.COLOR_BGR2YCrCb) yuv_test = cv2.cvtColor(img_test, cv2.COLOR_BGR2YCrCb)
img_h, img_w, img_c = img_test.shape img_h, img_w, img_c = img_test.shape
if args.show_img:
cv2.imshow('raw image', img_test) cv2.imshow('raw image', img_test)
if args.only_reconstruct == False: if args.only_reconstruct == False:
# blur image and cubic interpolation # blur image and cubic interpolation
...@@ -55,16 +55,17 @@ def reconstruct_img(args): ...@@ -55,16 +55,17 @@ def reconstruct_img(args):
result_img[result_img >255] = 255 result_img[result_img >255] = 255
gap_y = int((img_y.shape[0]-result_img.shape[2])/2) gap_y = int((img_y.shape[0]-result_img.shape[2])/2)
gap_x = int((img_y.shape[1]-result_img.shape[3])/2) gap_x = int((img_y.shape[1]-result_img.shape[3])/2)
if args.show_img:
cv2.imshow('input_channel y', img_y) cv2.imshow('input_channel y', img_y)
cv2.imwrite(os.path.join(os.path.split(args.image_path)[0],'beforeSR_'+os.path.split(args.image_path)[1]), img_y)
img_y[gap_y: gap_y + result_img.shape[2], img_y[gap_y: gap_y + result_img.shape[2],
gap_x: gap_x + result_img.shape[3]]=result_img gap_x: gap_x + result_img.shape[3]]=result_img
if args.show_img:
cv2.imshow('output_channel y', img_y)
cv2.waitKey(0)
cv2.imshow('output_channel y', img_y) cv2.destroyAllWindows()
cv2.waitKey(0) cv2.imwrite(os.path.join(os.path.split(args.image_path)[0],'afterSR_'+os.path.split(args.image_path)[1]), img_y)
cv2.destroyAllWindows()
return img_y return img_y
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -37,17 +37,17 @@ N2=32 ...@@ -37,17 +37,17 @@ N2=32
def net(X, Y, model_struct): def net(X, Y, model_struct):
# construct net # construct net
conv1 = fluid.layers.nn.conv2d(X, model_struct.n1, model_struct.f1, act='relu', name='conv1' , conv1 = fluid.layers.conv2d(X, model_struct.n1, model_struct.f1, act='relu', name='conv1' ,
param_attr= fluid.ParamAttr(initializer=fluid.initializer.NormalInitializer(scale=0.001), param_attr= fluid.ParamAttr(initializer=fluid.initializer.NormalInitializer(scale=0.001),
name='conv1_w'), name='conv1_w'),
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(value=0.), bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(value=0.),
name='conv1_b')) name='conv1_b'))
conv2 = fluid.layers.nn.conv2d(conv1, model_struct.n2, model_struct.f2, act='relu', name='conv2' , conv2 = fluid.layers.conv2d(conv1, model_struct.n2, model_struct.f2, act='relu', name='conv2' ,
param_attr= fluid.ParamAttr(initializer=fluid.initializer.NormalInitializer(scale=0.001), param_attr= fluid.ParamAttr(initializer=fluid.initializer.NormalInitializer(scale=0.001),
name='conv2_w'), name='conv2_w'),
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(value=0.), bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(value=0.),
name='conv2_b')) name='conv2_b'))
pred = fluid.layers.nn.conv2d(conv2, 1, model_struct.f3, name='pred', pred = fluid.layers.conv2d(conv2, 1, model_struct.f3, name='pred',
param_attr= fluid.ParamAttr(initializer=fluid.initializer.NormalInitializer(scale=0.001), param_attr= fluid.ParamAttr(initializer=fluid.initializer.NormalInitializer(scale=0.001),
name='pred_w'), name='pred_w'),
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(value=0.), bias_attr=fluid.ParamAttr(initializer=fluid.initializer.ConstantInitializer(value=0.),
...@@ -94,7 +94,7 @@ def train(args): ...@@ -94,7 +94,7 @@ def train(args):
fluid.framework.default_main_program(), fluid.framework.default_main_program(),
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=[y_loss]) fetch_list=[y_loss])
if batch_id == 0 or backprops_cnt % 100 ==0: if batch_id == 0:
fluid.io.save_inference_model(args.checkpoint_path, ['image'], [y_predict], exe) fluid.io.save_inference_model(args.checkpoint_path, ['image'], [y_predict], exe)
val_loss, val_psnr = validation() val_loss, val_psnr = validation()
print("%i\tEpoch: %d \tCur Cost : %f\t Val Cost: %f\t PSNR :%f" % (backprops_cnt, epoch, np.array(loss[0])[0], val_loss, val_psnr)) print("%i\tEpoch: %d \tCur Cost : %f\t Val Cost: %f\t PSNR :%f" % (backprops_cnt, epoch, np.array(loss[0])[0], val_loss, val_psnr))
......
...@@ -50,8 +50,8 @@ def read_data(data_path, batch_size, ext, scale_factor, bia_size): ...@@ -50,8 +50,8 @@ def read_data(data_path, batch_size, ext, scale_factor, bia_size):
img_blur = cv2.GaussianBlur(img_patch, (5, 5), 0) img_blur = cv2.GaussianBlur(img_patch, (5, 5), 0)
img_sumsample = cv2.resize(img_blur, (int(33/scale_factor), int(33/scale_factor))) img_sumsample = cv2.resize(img_blur, (int(33/scale_factor), int(33/scale_factor)))
img_input = cv2.resize(img_blur, (33, 33), interpolation=cv2.INTER_CUBIC) img_input = cv2.resize(img_blur, (33, 33), interpolation=cv2.INTER_CUBIC)
img_inputs.append(img_input) img_inputs.append([img_input])
img_gths.append(img_gth) img_gths.append([img_gth])
count_bt += 1 count_bt += 1
if count_bt % batch_size == 0: if count_bt % batch_size == 0:
yield [[np.array(img_inputs), np.array(img_gths)]] yield [[np.array(img_inputs), np.array(img_gths)]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册