From 6e70b94de3cb39fd4f532267a565db0fc626ea9e Mon Sep 17 00:00:00 2001 From: kechxu Date: Sat, 26 Jan 2019 14:30:27 -0800 Subject: [PATCH] Prediction: refactor offline mode, gflags, feature proto and data for learning --- modules/prediction/common/BUILD | 1 + modules/prediction/common/feature_output.cc | 12 +++-- modules/prediction/common/feature_output.h | 7 ++- .../prediction/common/feature_output_test.cc | 4 +- modules/prediction/common/message_process.cc | 4 +- .../common/prediction_system_gflags.cc | 8 +-- .../common/prediction_system_gflags.h | 3 +- modules/prediction/conf/prediction.conf | 3 +- .../obstacles/obstacles_container.cc | 6 +-- .../evaluator/vehicle/cruise_mlp_evaluator.cc | 50 +++++++++---------- .../vehicle/junction_mlp_evaluator.cc | 2 +- .../vehicle/lane_scanning_evaluator.cc | 2 +- .../evaluator/vehicle/mlp_evaluator.cc | 12 +++-- modules/prediction/prediction_component.cc | 25 ---------- 14 files changed, 61 insertions(+), 78 deletions(-) diff --git a/modules/prediction/common/BUILD b/modules/prediction/common/BUILD index 13969ab76f..b76fd91bb0 100644 --- a/modules/prediction/common/BUILD +++ b/modules/prediction/common/BUILD @@ -68,6 +68,7 @@ cc_library( "//modules/common/util", "//modules/prediction/common:prediction_gflags", "//modules/prediction/proto:offline_features_proto", + "//modules/prediction/proto:prediction_proto", ], ) diff --git a/modules/prediction/common/feature_output.cc b/modules/prediction/common/feature_output.cc index 5e3a053c2b..727a57be2f 100644 --- a/modules/prediction/common/feature_output.cc +++ b/modules/prediction/common/feature_output.cc @@ -26,15 +26,17 @@ namespace prediction { Features FeatureOutput::features_; ListDataForLearning FeatureOutput::list_data_for_learning_; +PredictionObstacles prediction_obstacles_; std::size_t FeatureOutput::idx_feature_ = 0; std::size_t FeatureOutput::idx_learning_ = 0; +std::size_t idx_prediction_obstacle_ = 0; void FeatureOutput::Close() { ADEBUG << "Close feature output"; - if (FLAGS_prediction_offline_mode) { - Write(); + if (FLAGS_prediction_offline_mode == 1) { + WriteFeatureProto(); } - if (FLAGS_prediction_offline_dataforlearning) { + if (FLAGS_prediction_offline_mode == 2) { WriteDataForLearning(); } Clear(); @@ -52,7 +54,7 @@ bool FeatureOutput::Ready() { return true; } -void FeatureOutput::Insert(const Feature& feature) { +void FeatureOutput::InsertFeatureProto(const Feature& feature) { features_.add_feature()->CopyFrom(feature); } @@ -70,7 +72,7 @@ void FeatureOutput::InsertDataForLearning( ADEBUG << "Insert [" << category << "] data for learning"; } -void FeatureOutput::Write() { +void FeatureOutput::WriteFeatureProto() { if (features_.feature_size() <= 0) { ADEBUG << "Skip writing empty feature."; } else { diff --git a/modules/prediction/common/feature_output.h b/modules/prediction/common/feature_output.h index ebcd759dbb..fc59a5050e 100644 --- a/modules/prediction/common/feature_output.h +++ b/modules/prediction/common/feature_output.h @@ -20,6 +20,7 @@ #include #include "modules/prediction/proto/offline_features.pb.h" +#include "modules/prediction/proto/prediction_obstacle.pb.h" namespace apollo { namespace prediction { @@ -51,7 +52,7 @@ class FeatureOutput { * @brief Insert a feature * @param A feature in proto */ - static void Insert(const Feature& feature); + static void InsertFeatureProto(const Feature& feature); /** * @brief Insert a data_for_learning @@ -64,7 +65,7 @@ class FeatureOutput { /** * @brief Write features to a file */ - static void Write(); + static void WriteFeatureProto(); /** * @brief Write DataForLearning features to a file @@ -88,6 +89,8 @@ class FeatureOutput { static std::size_t idx_feature_; static ListDataForLearning list_data_for_learning_; static std::size_t idx_learning_; + static PredictionObstacles prediction_obstacles_; + static std::size_t idx_prediction_obstacle_; }; } // namespace prediction diff --git a/modules/prediction/common/feature_output_test.cc b/modules/prediction/common/feature_output_test.cc index ca945751cf..052649fb5b 100644 --- a/modules/prediction/common/feature_output_test.cc +++ b/modules/prediction/common/feature_output_test.cc @@ -35,7 +35,7 @@ TEST_F(FeatureOutputTest, insertion) { Feature feature; for (int i = 0; i < 3; ++i) { Feature feature; - FeatureOutput::Insert(feature); + FeatureOutput::InsertFeatureProto(feature); } EXPECT_EQ(3, FeatureOutput::Size()); } @@ -44,7 +44,7 @@ TEST_F(FeatureOutputTest, clear) { Feature feature; for (int i = 0; i < 3; ++i) { Feature feature; - FeatureOutput::Insert(feature); + FeatureOutput::InsertFeatureProto(feature); } FeatureOutput::Clear(); EXPECT_EQ(0, FeatureOutput::Size()); diff --git a/modules/prediction/common/message_process.cc b/modules/prediction/common/message_process.cc index 9858efa777..bc202b1653 100644 --- a/modules/prediction/common/message_process.cc +++ b/modules/prediction/common/message_process.cc @@ -144,7 +144,7 @@ void MessageProcess::OnPerception( auto end_time6 = std::chrono::system_clock::now(); // Insert features to FeatureOutput for offline_mode - if (FLAGS_prediction_offline_mode) { + if (FLAGS_prediction_offline_mode == 1) { for (const int id : ptr_obstacles_container->curr_frame_predictable_obstacle_ids()) { Obstacle* obstacle_ptr = ptr_obstacles_container->GetObstacle(id); @@ -155,7 +155,7 @@ void MessageProcess::OnPerception( AERROR << "Obstacle [" << id << "] has no latest feature."; return; } - FeatureOutput::Insert(obstacle_ptr->latest_feature()); + FeatureOutput::InsertFeatureProto(obstacle_ptr->latest_feature()); ADEBUG << "Insert feature into feature output"; } // Not doing evaluation on offline mode diff --git a/modules/prediction/common/prediction_system_gflags.cc b/modules/prediction/common/prediction_system_gflags.cc index 0d98b951c3..ece8d8e398 100644 --- a/modules/prediction/common/prediction_system_gflags.cc +++ b/modules/prediction/common/prediction_system_gflags.cc @@ -47,9 +47,11 @@ DEFINE_string( "a list of bag files or directories for offline mode. The items need to be " "separated by colon ':'. If this value is not set, the prediction module " "will use the listen to published ros topic mode."); -DEFINE_bool(prediction_offline_mode, false, "Prediction offline mode"); -DEFINE_bool(prediction_offline_dataforlearning, false, "Whether to extract " - "the features for offline learning-models training."); +DEFINE_int32(prediction_offline_mode, 0, + "0: online mode, no dump file" + "1: dump feature proto to feature.x.bin" + "2: dump data for learning to datalearn.x.bin" + "3: dump predicted trajectory to predict_obstacles.x.bin"); // Bag replay timestamp gap DEFINE_double(replay_timestamp_gap, 10.0, diff --git a/modules/prediction/common/prediction_system_gflags.h b/modules/prediction/common/prediction_system_gflags.h index 8396eb2e47..7cbccc773b 100644 --- a/modules/prediction/common/prediction_system_gflags.h +++ b/modules/prediction/common/prediction_system_gflags.h @@ -31,8 +31,7 @@ DECLARE_bool(prediction_test_mode); DECLARE_double(prediction_test_duration); DECLARE_string(prediction_offline_bags); -DECLARE_bool(prediction_offline_mode); -DECLARE_bool(prediction_offline_dataforlearning); +DECLARE_int32(prediction_offline_mode); // Bag replay timestamp gap DECLARE_double(replay_timestamp_gap); diff --git a/modules/prediction/conf/prediction.conf b/modules/prediction/conf/prediction.conf index 4d8aaa6b5c..8956fa0bc7 100644 --- a/modules/prediction/conf/prediction.conf +++ b/modules/prediction/conf/prediction.conf @@ -4,8 +4,7 @@ --noadjust_velocity_by_obstacle_heading --noadjust_velocity_by_position_shift --noenable_kf_tracking ---noprediction_offline_mode ---noprediction_offline_dataforlearning +--prediction_offline_mode=0 --lane_change_dist=10.0 diff --git a/modules/prediction/container/obstacles/obstacles_container.cc b/modules/prediction/container/obstacles/obstacles_container.cc index de5a618ba3..06fda5d9b2 100644 --- a/modules/prediction/container/obstacles/obstacles_container.cc +++ b/modules/prediction/container/obstacles/obstacles_container.cc @@ -67,13 +67,13 @@ void ObstaclesContainer::Insert(const ::google::protobuf::Message& message) { << timestamp_ << "]."; return; } - if (FLAGS_prediction_offline_mode) { + if (FLAGS_prediction_offline_mode == 1) { if (std::fabs(timestamp - timestamp_) > FLAGS_replay_timestamp_gap || FeatureOutput::Size() > FLAGS_max_num_dump_feature) { - FeatureOutput::Write(); + FeatureOutput::WriteFeatureProto(); } } - if (FLAGS_prediction_offline_dataforlearning) { + if (FLAGS_prediction_offline_mode == 2) { if (std::fabs(timestamp - timestamp_) > FLAGS_replay_timestamp_gap || FeatureOutput::SizeOfDataForLearning() > FLAGS_max_num_dump_feature) { FeatureOutput::WriteDataForLearning(); diff --git a/modules/prediction/evaluator/vehicle/cruise_mlp_evaluator.cc b/modules/prediction/evaluator/vehicle/cruise_mlp_evaluator.cc index 68ea30a666..40d4496427 100644 --- a/modules/prediction/evaluator/vehicle/cruise_mlp_evaluator.cc +++ b/modules/prediction/evaluator/vehicle/cruise_mlp_evaluator.cc @@ -92,25 +92,31 @@ void CruiseMLPEvaluator::Evaluate(Obstacle* obstacle_ptr) { continue; } - if (!FLAGS_prediction_offline_mode) { - Eigen::MatrixXf obs_feature_mat = VectorToMatrixXf(feature_values, 0, - OBSTACLE_FEATURE_SIZE); - Eigen::MatrixXf lane_feature_mat = VectorToMatrixXf(feature_values, - OBSTACLE_FEATURE_SIZE + INTERACTION_FEATURE_SIZE, - static_cast(feature_values.size()), SINGLE_LANE_FEATURE_SIZE, - LANE_POINTS_SIZE); - Eigen::MatrixXf model_output; - if (lane_sequence_ptr->vehicle_on_lane()) { - go_model_ptr_->Run({lane_feature_mat, obs_feature_mat}, &model_output); - } else { - cutin_model_ptr_->Run( - {lane_feature_mat, obs_feature_mat}, &model_output); - } - double probability = model_output(0, 0); - double finish_time = model_output(0, 1); - lane_sequence_ptr->set_probability(probability); - lane_sequence_ptr->set_time_to_lane_center(finish_time); + // Insert features to DataForLearning + if (FLAGS_prediction_offline_mode == 2) { + FeatureOutput::InsertDataForLearning( + *latest_feature_ptr, feature_values, "junction"); + ADEBUG << "Save extracted features for learning locally."; + return; // Skip Compute probability for offline mode + } + + Eigen::MatrixXf obs_feature_mat = VectorToMatrixXf(feature_values, 0, + OBSTACLE_FEATURE_SIZE); + Eigen::MatrixXf lane_feature_mat = VectorToMatrixXf(feature_values, + OBSTACLE_FEATURE_SIZE + INTERACTION_FEATURE_SIZE, + static_cast(feature_values.size()), SINGLE_LANE_FEATURE_SIZE, + LANE_POINTS_SIZE); + Eigen::MatrixXf model_output; + if (lane_sequence_ptr->vehicle_on_lane()) { + go_model_ptr_->Run({lane_feature_mat, obs_feature_mat}, &model_output); + } else { + cutin_model_ptr_->Run( + {lane_feature_mat, obs_feature_mat}, &model_output); } + double probability = model_output(0, 0); + double finish_time = model_output(0, 1); + lane_sequence_ptr->set_probability(probability); + lane_sequence_ptr->set_time_to_lane_center(finish_time); } } @@ -166,14 +172,6 @@ void CruiseMLPEvaluator::ExtractFeatureValues feature_values->insert(feature_values->end(), lane_feature_values.begin(), lane_feature_values.end()); - - // For offline training, write the extracted features into proto. - if (FLAGS_prediction_offline_mode) { - SaveOfflineFeatures(lane_sequence_ptr, *feature_values); - ADEBUG << "Save cruise mlp features for obstacle [" - << obstacle_ptr->id() << "] with dim [" - << feature_values->size() << "]"; - } } void CruiseMLPEvaluator::SetObstacleFeatureValues( diff --git a/modules/prediction/evaluator/vehicle/junction_mlp_evaluator.cc b/modules/prediction/evaluator/vehicle/junction_mlp_evaluator.cc index a55d4e8546..db1373742c 100644 --- a/modules/prediction/evaluator/vehicle/junction_mlp_evaluator.cc +++ b/modules/prediction/evaluator/vehicle/junction_mlp_evaluator.cc @@ -75,7 +75,7 @@ void JunctionMLPEvaluator::Evaluate(Obstacle* obstacle_ptr) { ExtractFeatureValues(obstacle_ptr, &feature_values); // Insert features to DataForLearning - if (FLAGS_prediction_offline_dataforlearning) { + if (FLAGS_prediction_offline_mode == 2) { FeatureOutput::InsertDataForLearning( *latest_feature_ptr, feature_values, "junction"); ADEBUG << "Save extracted features for learning locally."; diff --git a/modules/prediction/evaluator/vehicle/lane_scanning_evaluator.cc b/modules/prediction/evaluator/vehicle/lane_scanning_evaluator.cc index 9de9b18571..4c5bf8d349 100644 --- a/modules/prediction/evaluator/vehicle/lane_scanning_evaluator.cc +++ b/modules/prediction/evaluator/vehicle/lane_scanning_evaluator.cc @@ -72,7 +72,7 @@ void LaneScanningEvaluator::Evaluate( std::vector feature_values; ExtractFeatures(obstacle_ptr, lane_graph_ptr, &feature_values); std::vector labels = {0.0}; - if (FLAGS_prediction_offline_dataforlearning) { + if (FLAGS_prediction_offline_mode == 2) { FeatureOutput::InsertDataForLearning(*latest_feature_ptr, feature_values, "cruise"); ADEBUG << "Save extracted features for learning locally."; diff --git a/modules/prediction/evaluator/vehicle/mlp_evaluator.cc b/modules/prediction/evaluator/vehicle/mlp_evaluator.cc index 26e1b487a3..a912d4b05b 100644 --- a/modules/prediction/evaluator/vehicle/mlp_evaluator.cc +++ b/modules/prediction/evaluator/vehicle/mlp_evaluator.cc @@ -80,6 +80,14 @@ void MLPEvaluator::Evaluate(Obstacle* obstacle_ptr) { CHECK(lane_sequence_ptr != nullptr); std::vector feature_values; ExtractFeatureValues(obstacle_ptr, lane_sequence_ptr, &feature_values); + // Insert features to DataForLearning + if (FLAGS_prediction_offline_mode == 2 && + !obstacle_ptr->IsNearJunction()) { + FeatureOutput::InsertDataForLearning( + *latest_feature_ptr, feature_values, "mlp"); + ADEBUG << "Save extracted features for learning locally."; + return; // Skip Compute probability for offline mode + } double probability = ComputeProbability(feature_values); double centripetal_acc_probability = @@ -124,10 +132,6 @@ void MLPEvaluator::ExtractFeatureValues(Obstacle* obstacle_ptr, obstacle_feature_values.end()); feature_values->insert(feature_values->end(), lane_feature_values.begin(), lane_feature_values.end()); - - if (FLAGS_prediction_offline_mode && !obstacle_ptr->IsNearJunction()) { - SaveOfflineFeatures(lane_sequence_ptr, *feature_values); - } } void MLPEvaluator::SaveOfflineFeatures( diff --git a/modules/prediction/prediction_component.cc b/modules/prediction/prediction_component.cc index 6a20f93447..72e9f1dc5e 100644 --- a/modules/prediction/prediction_component.cc +++ b/modules/prediction/prediction_component.cc @@ -83,31 +83,6 @@ bool PredictionComponent::Init() { prediction_writer_ = node_->CreateWriter(FLAGS_prediction_topic); - if (FLAGS_prediction_offline_mode) { - if (!FeatureOutput::Ready()) { - AERROR << "Feature output is not ready."; - return false; - } - if (FLAGS_prediction_offline_bags.empty()) { - return true; // use listen to ROS topic mode - } - std::vector inputs; - common::util::Split(FLAGS_prediction_offline_bags, ':', &inputs); - for (const auto& input : inputs) { - std::vector offline_bags; - GetRecordFileNames(boost::filesystem::path(input), &offline_bags); - std::sort(offline_bags.begin(), offline_bags.end()); - AINFO << "For input " << input << ", found " << offline_bags.size() - << " rosbags to process"; - for (std::size_t i = 0; i < offline_bags.size(); ++i) { - AINFO << "\tProcessing: [ " << i << " / " << offline_bags.size() - << " ]: " << offline_bags[i]; - MessageProcess::ProcessOfflineData(offline_bags[i]); - } - } - FeatureOutput::Close(); - return false; - } return true; } -- GitLab