未验证 提交 1f1d80d5 编写于 作者: W whs 提交者: GitHub

1. Fix load checkpoint (#1971)

2. Refine readme
上级 08b8c05d
运行本目录下的程序示例需要使用PaddlePaddle develop最新版本。如果您的PaddlePaddle安装版本低于此要求,请按照[安装文档](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html)中的说明更新PaddlePaddle安装版本。
## 代码结构
```
......@@ -68,7 +66,7 @@ Iter[0]; train loss: 2.338; sub4_loss: 3.367; sub24_loss: 4.120; sub124_loss: 0.
### 测试
执行以下命令在`Cityscape`测试数据集上进行测试:
```
python eval.py --model_path="./model/" --use_gpu=True
python eval.py --model_path="./cnkpnt/100" --use_gpu=True
```
需要通过选项`--model_path`指定模型文件。
测试脚本的输出的评估指标为[mean IoU]()。
......@@ -77,7 +75,7 @@ python eval.py --model_path="./model/" --use_gpu=True
执行以下命令对指定的数据进行预测:
```
python infer.py \
--model_path="./model" \
--model_path="./cnkpnt/100" \
--images_path="./data/cityscape/" \
--images_list="./data/cityscape/infer.list"
```
......
......@@ -122,6 +122,7 @@ def infer(args):
fetch_list=[predict])
cv2.imwrite(args.out_path + "/" + filename + "_result.png",
color(result[0]))
print("Saved images into: %s" % args.out_path)
def main():
......
......@@ -96,8 +96,11 @@ def train(args):
if args.init_model is not None:
print("load model from: %s" % args.init_model)
sys.stdout.flush()
fluid.io.load_params(exe, args.init_model)
def if_exist(var):
return os.path.exists(os.path.join(args.init_model, var.name))
fluid.io.load_vars(exe, args.init_model, predicate=if_exist)
iter_id = 0
t_loss = 0.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册