From 96469e70cc22f6297f9b9d6813e62ca0aa79eec2 Mon Sep 17 00:00:00 2001 From: Bai Yifan Date: Fri, 17 Aug 2018 22:12:01 +0800 Subject: [PATCH] Refine ce support for object detection (#1165) * update ce --- fluid/object_detection/_ce.py | 2 +- fluid/object_detection/train.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fluid/object_detection/_ce.py b/fluid/object_detection/_ce.py index f536e8a1..f90887c9 100644 --- a/fluid/object_detection/_ce.py +++ b/fluid/object_detection/_ce.py @@ -8,7 +8,7 @@ from kpi import CostKpi, DurationKpi, AccKpi #### NOTE kpi.py should shared in models in some way!!!! train_cost_kpi = CostKpi('train_cost', 0.02, 0, actived=True) -test_acc_kpi = AccKpi('test_acc', 0.01, 0, actived=True) +test_acc_kpi = AccKpi('test_acc', 0.01, 0, actived=False) train_speed_kpi = AccKpi('train_speed', 0.2, 0, actived=False) train_cost_card4_kpi = CostKpi('train_cost_card4', 0.02, 0, actived=True) test_acc_card4_kpi = AccKpi('test_acc_card4', 0.01, 0, actived=True) diff --git a/fluid/object_detection/train.py b/fluid/object_detection/train.py index 46af235f..a07983ba 100644 --- a/fluid/object_detection/train.py +++ b/fluid/object_detection/train.py @@ -65,7 +65,6 @@ def train(args, name='gt_label', shape=[1], dtype='int32', lod_level=1) difficult = fluid.layers.data( name='gt_difficult', shape=[1], dtype='int32', lod_level=1) - locs, confs, box, box_var = mobile_net(num_classes, image, image_shape) nmsed_out = fluid.layers.detection_output( locs, confs, box, box_var, nms_threshold=args.nms_threshold) @@ -126,6 +125,9 @@ def train(args, train_reader = paddle.batch( reader.train(data_args, train_file_list), batch_size=batch_size) else: + import random + random.seed(0) + np.random.seed(0) train_reader = paddle.batch( reader.train(data_args, train_file_list, False), batch_size=batch_size) test_reader = paddle.batch( @@ -166,8 +168,6 @@ def train(args, start_time = time.time() prev_start_time = start_time every_pass_loss = [] - iter = 0 - pass_duration = 0.0 for batch_id, data in enumerate(train_reader()): prev_start_time = start_time start_time = time.time() -- GitLab