未验证 提交 96469e70 编写于 作者: B Bai Yifan 提交者: GitHub

Refine ce support for object detection (#1165)

* update ce
上级 0db126d8
...@@ -8,7 +8,7 @@ from kpi import CostKpi, DurationKpi, AccKpi ...@@ -8,7 +8,7 @@ from kpi import CostKpi, DurationKpi, AccKpi
#### NOTE kpi.py should shared in models in some way!!!! #### NOTE kpi.py should shared in models in some way!!!!
train_cost_kpi = CostKpi('train_cost', 0.02, 0, actived=True) train_cost_kpi = CostKpi('train_cost', 0.02, 0, actived=True)
test_acc_kpi = AccKpi('test_acc', 0.01, 0, actived=True) test_acc_kpi = AccKpi('test_acc', 0.01, 0, actived=False)
train_speed_kpi = AccKpi('train_speed', 0.2, 0, actived=False) train_speed_kpi = AccKpi('train_speed', 0.2, 0, actived=False)
train_cost_card4_kpi = CostKpi('train_cost_card4', 0.02, 0, actived=True) train_cost_card4_kpi = CostKpi('train_cost_card4', 0.02, 0, actived=True)
test_acc_card4_kpi = AccKpi('test_acc_card4', 0.01, 0, actived=True) test_acc_card4_kpi = AccKpi('test_acc_card4', 0.01, 0, actived=True)
......
...@@ -65,7 +65,6 @@ def train(args, ...@@ -65,7 +65,6 @@ def train(args,
name='gt_label', shape=[1], dtype='int32', lod_level=1) name='gt_label', shape=[1], dtype='int32', lod_level=1)
difficult = fluid.layers.data( difficult = fluid.layers.data(
name='gt_difficult', shape=[1], dtype='int32', lod_level=1) name='gt_difficult', shape=[1], dtype='int32', lod_level=1)
locs, confs, box, box_var = mobile_net(num_classes, image, image_shape) locs, confs, box, box_var = mobile_net(num_classes, image, image_shape)
nmsed_out = fluid.layers.detection_output( nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=args.nms_threshold) locs, confs, box, box_var, nms_threshold=args.nms_threshold)
...@@ -126,6 +125,9 @@ def train(args, ...@@ -126,6 +125,9 @@ def train(args,
train_reader = paddle.batch( train_reader = paddle.batch(
reader.train(data_args, train_file_list), batch_size=batch_size) reader.train(data_args, train_file_list), batch_size=batch_size)
else: else:
import random
random.seed(0)
np.random.seed(0)
train_reader = paddle.batch( train_reader = paddle.batch(
reader.train(data_args, train_file_list, False), batch_size=batch_size) reader.train(data_args, train_file_list, False), batch_size=batch_size)
test_reader = paddle.batch( test_reader = paddle.batch(
...@@ -166,8 +168,6 @@ def train(args, ...@@ -166,8 +168,6 @@ def train(args,
start_time = time.time() start_time = time.time()
prev_start_time = start_time prev_start_time = start_time
every_pass_loss = [] every_pass_loss = []
iter = 0
pass_duration = 0.0
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
prev_start_time = start_time prev_start_time = start_time
start_time = time.time() start_time = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册