From 0258bd4373e64cc6c61c8a2a3734005ef970c238 Mon Sep 17 00:00:00 2001 From: panjiacheng Date: Thu, 1 Nov 2018 13:02:07 -0700 Subject: [PATCH] Prediction: cruise training feature refactoring. --- modules/prediction/proto/lane_graph.proto | 3 ++- .../prediction/mlp_train/common/trajectory.py | 16 +++++++++++++--- .../prediction/mlp_train/cruiseMLP_train.py | 2 +- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/modules/prediction/proto/lane_graph.proto b/modules/prediction/proto/lane_graph.proto index 73aa0587ca..0e8056f727 100644 --- a/modules/prediction/proto/lane_graph.proto +++ b/modules/prediction/proto/lane_graph.proto @@ -44,7 +44,7 @@ message NearbyObstacle { optional double l = 3; // relative to focus obstacle } -// next id = 19 +// next id = 20 message LaneSequence { optional int32 lane_sequence_id = 1; repeated LaneSegment lane_segment = 2; @@ -62,6 +62,7 @@ message LaneSequence { optional double probability = 6 [default = 0.0]; optional double acceleration = 7 [default = 0.0]; optional double time_to_lane_center = 16; + optional double time_to_lane_edge = 19; repeated apollo.common.PathPoint path_point = 8; enum BehaviorType { diff --git a/modules/tools/prediction/mlp_train/common/trajectory.py b/modules/tools/prediction/mlp_train/common/trajectory.py index 796d27c258..521f045029 100644 --- a/modules/tools/prediction/mlp_train/common/trajectory.py +++ b/modules/tools/prediction/mlp_train/common/trajectory.py @@ -203,6 +203,7 @@ class TrajectoryToSample(object): new_lane_id = None has_started_lane_change = False has_finished_lane_change = False + lane_change_start_time = None lane_change_finish_time = 10.0 # Go through all the subsequent features in this sequence @@ -222,12 +223,14 @@ class TrajectoryToSample(object): # If step into another lane, label lane change to be started. lane_id_j = trajectory[j].lane.lane_feature.lane_id if lane_id_j not in curr_lane_seq: - has_started_lane_change = True - lane_change_finish_time = 10.0 - if new_lane_id is None: + if has_started_lane_change = False: + has_started_lane_change = True + lane_change_start_time = time_span + lane_change_finish_time = 10.0 new_lane_id = lane_id_j else: has_started_lane_change = False + new_lane_id = None # If roughly get to the center of another lane, label lane change to be finished. left_bound = trajectory[j].lane.lane_feature.dist_to_left_boundary @@ -267,20 +270,24 @@ class TrajectoryToSample(object): # Obstacle is following the original lane but is never at lane-center: if lane_change_finish_time == 10.0: lane_sequence.label = 4 + lane_sequence.time_to_lane_edge = 10.0 lane_sequence.time_to_lane_center = 10.0 # Obstacle is following the original lane and moved to lane-center else: lane_sequence.label = 1 + lane_sequence.time_to_lane_edge = 10.0 lane_sequence.time_to_lane_center = lane_change_finish_time # Obs has stepped out of this lane within 6sec. else: lane_sequence.label = 0 + lane_sequence.time_to_lane_edge = lane_change_start_time lane_sequence.time_to_lane_center = 100.0 # The current lane is NOT obstacle's original lane. else: # Obstacle is following the original lane. if not has_started_lane_change: lane_sequence.label = -1 + lane_sequence.time_to_lane_edge = 100.0 lane_sequence.time_to_lane_center = 100.0 else: new_lane_id_is_in_this_lane_seq = False @@ -293,15 +300,18 @@ class TrajectoryToSample(object): # Obstacle has finished lane changing within 6 sec. if has_finished_lane_change: lane_sequence.label = 2 + lane_sequence.time_to_lane_edge = lane_change_start_time lane_sequence.time_to_lane_center = lane_change_finish_time # Obstacle started lane changing but haven't finished yet. else: lane_sequence.label = 3 + lane_sequence.time_to_lane_edge = lane_change_start_time lane_sequence.time_to_lane_center = 10.0 # Obstacle has changed to some other lane. else: lane_sequence.label = -1 + lane_sequence.time_to_lane_edge = 100.0 lane_sequence.time_to_lane_center = 100.0 return trajectory diff --git a/modules/tools/prediction/mlp_train/cruiseMLP_train.py b/modules/tools/prediction/mlp_train/cruiseMLP_train.py index b476c6be61..2b0ad91038 100644 --- a/modules/tools/prediction/mlp_train/cruiseMLP_train.py +++ b/modules/tools/prediction/mlp_train/cruiseMLP_train.py @@ -181,7 +181,7 @@ Custom defined loss function that lumps the loss of classification and of regression together. ''' def loss_fn(c_pred, r_pred, target): - loss_C = nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor([39.0]).cuda()) #nn.BCELoss() + loss_C = nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor([1.0]).cuda()) #nn.BCELoss() loss_R = nn.MSELoss() loss = loss_C(c_pred, target[:,0].view(target.shape[0],1)) #loss = 4 * loss_C(c_pred, target[:,0].view(target.shape[0],1)) + \ -- GitLab