diff --git a/modules/prediction/common/feature_output.cc b/modules/prediction/common/feature_output.cc index 3f2d6150b6fc93986e69804bf72a990e9b6adff8..509f97908141feb822693c9bfaa499a1bfb3cd45 100644 --- a/modules/prediction/common/feature_output.cc +++ b/modules/prediction/common/feature_output.cc @@ -26,6 +26,7 @@ namespace apollo { namespace prediction { using apollo::common::util::StrCat; +using apollo::common::TrajectoryPoint; Features FeatureOutput::features_; ListDataForLearning FeatureOutput::list_data_for_learning_; @@ -132,7 +133,8 @@ void FeatureOutput::InsertFrameEnv(const FrameEnv& frame_env) { void FeatureOutput::InsertDataForTuning( const Feature& feature, const std::vector& feature_values, - const std::string& category, const LaneSequence& lane_sequence) { + const std::string& category, const LaneSequence& lane_sequence, + const std::vector& adc_trajectory) { DataForTuning* data_for_tuning = list_data_for_tuning_.add_data_for_tuning(); data_for_tuning->set_id(feature.id()); data_for_tuning->set_timestamp(feature.timestamp()); @@ -142,6 +144,9 @@ void FeatureOutput::InsertDataForTuning( ADEBUG << "Insert [" << category << "] data for tuning with size = " << feature_values.size(); data_for_tuning->set_lane_sequence_id(lane_sequence.lane_sequence_id()); + for (const auto& adc_traj_point : adc_trajectory) { + data_for_tuning->add_adc_trajectory_point()->CopyFrom(adc_traj_point); + } } void FeatureOutput::WriteFeatureProto() { diff --git a/modules/prediction/common/feature_output.h b/modules/prediction/common/feature_output.h index 6d0b28b70f826bf43cb16ebfbbc42e8536c95695..7aeaf8de1050b3b623050c52dbf9d9ceed19bda0 100644 --- a/modules/prediction/common/feature_output.h +++ b/modules/prediction/common/feature_output.h @@ -88,11 +88,12 @@ class FeatureOutput { * @param values for tuning * @param category of the data * @param lane sequence + * @param adc trajectory */ static void InsertDataForTuning(const Feature& feature, - const std::vector& feature_values, - const std::string& category, - const LaneSequence& lane_sequence); + const std::vector& feature_values, + const std::string& category, const LaneSequence& lane_sequence, + const std::vector& adc_trajectory); /** * @brief Write features to a file diff --git a/modules/prediction/predictor/interaction/interaction_predictor.cc b/modules/prediction/predictor/interaction/interaction_predictor.cc index a8467d0f45a8342fa96ad4a4d916fd99159a134e..c975a9ccd1edb31eb741e5df57b0a9741790f46f 100644 --- a/modules/prediction/predictor/interaction/interaction_predictor.cc +++ b/modules/prediction/predictor/interaction/interaction_predictor.cc @@ -261,7 +261,8 @@ double InteractionPredictor::ComputeTrajectoryCost(const Obstacle& obstacle, std::vector cost_values = {lon_acc_cost, centri_acc_cost, collision_cost}; FeatureOutput::InsertDataForTuning( - obstacle.latest_feature(), cost_values, "interaction", lane_sequence); + obstacle.latest_feature(), cost_values, "interaction", lane_sequence, + adc_trajectory_); } return total_cost; diff --git a/modules/prediction/proto/BUILD b/modules/prediction/proto/BUILD index e962097e665f14f31b3b0998b7261021031a856a..db71b26745ef3764047e85399280b24446cf44c6 100644 --- a/modules/prediction/proto/BUILD +++ b/modules/prediction/proto/BUILD @@ -91,6 +91,7 @@ proto_library( "offline_features.proto", ], deps = [ + "//modules/common/proto:pnc_point_proto_lib", "//modules/prediction/proto:feature_proto_lib", "//modules/prediction/proto:prediction_proto_lib", ], diff --git a/modules/prediction/proto/offline_features.proto b/modules/prediction/proto/offline_features.proto index 325c9c654cd2e81071a27ca80ded14451be37667..83390af6da91eb42f095185623ab4251d264f38f 100644 --- a/modules/prediction/proto/offline_features.proto +++ b/modules/prediction/proto/offline_features.proto @@ -2,6 +2,7 @@ syntax = "proto2"; package apollo.prediction; +import "modules/common/proto/pnc_point.proto"; import "modules/prediction/proto/feature.proto"; message CNNFeatures { @@ -58,6 +59,8 @@ message DataForTuning { repeated double real_cost_value = 5; // The lane sequence id if associated with a lane optional int32 lane_sequence_id = 6; + // The associated adc trajectory + repeated apollo.common.TrajectoryPoint adc_trajectory_point = 7; } message ListDataForTuning {