提交 6f772ada 编写于 作者: W wanghaoshuang

Fix ce of icnet.

上级 386b9ce9
#!/bin/bash
# This file is only used for continuous evaluation.
rm -rf ./ck
mkdir ck
python train.py --use_gpu=True --checkpoint_path="./ck"; python eval.py --model_path="./ck/100" | python _ce.py
rm -rf *_factor.txt
python train.py --use_gpu=True 1> log
cat log | python _ce.py
......@@ -8,12 +8,10 @@ from kpi import CostKpi, DurationKpi, AccKpi
# NOTE kpi.py should shared in models in some way!!!!
train_cost_kpi = CostKpi('train_cost', 0.02, actived=True)
test_acc_kpi = AccKpi('test_acc', 0.005, actived=True)
train_duration_kpi = DurationKpi('train_duration', 0.06, actived=True)
tracking_kpis = [
train_cost_kpi,
test_acc_kpi,
train_duration_kpi,
]
......
......@@ -20,12 +20,12 @@ add_arg('use_gpu', bool, True, "Whether use GPU to test.")
def cal_mean_iou(wrong, correct):
sum = wrong + cerroct
sum = wrong + correct
true_num = (sum != 0).sum()
for i in len(sum):
for i in range(len(sum)):
if sum[i] == 0:
sum[i] = 1
return (cerroct.astype("float64") / sum).sum() / true_num
return (correct.astype("float64") / sum).sum() / true_num
def create_iou(predict, label, mask, num_classes, image_shape):
......
......@@ -184,7 +184,7 @@ def res_block(input, filter_num, padding=0, dilation=None, name=None):
tmp = conv(tmp, 1, 1, filter_num, 1, 1, name=name + "_1_1_increase")
tmp = bn(tmp, relu=False)
tmp = input + tmp
tmp = fluid.layers.relu(tmp, name=name + "_relu")
tmp = fluid.layers.relu(tmp)
return tmp
......@@ -227,7 +227,7 @@ def proj_block(input, filter_num, padding=0, dilation=None, stride=1,
tmp = conv(tmp, 1, 1, filter_num, 1, 1, name=name + "_1_1_increase")
tmp = bn(tmp, relu=False)
tmp = proj_bn + tmp
tmp = fluid.layers.relu(tmp, name=name + "_relu")
tmp = fluid.layers.relu(tmp)
return tmp
......
......@@ -130,7 +130,7 @@ def train(args):
sub124_loss = 0.
sys.stdout.flush()
if iter_id % CHECKPOINT_PERIOD == 0:
if iter_id % CHECKPOINT_PERIOD == 0 and args.checkpoint_path is not None:
dir_name = args.checkpoint_path + "/" + str(iter_id)
fluid.io.save_persistables(exe, dirname=dir_name)
print "Saved checkpoint: %s" % (dir_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册