diff --git a/modules/prediction/common/prediction_gflags.cc b/modules/prediction/common/prediction_gflags.cc index 034cb29c1204f331ce80253cde90bcf30f8e02ab..ef8be8cdb1f97cae3fc490d9e4f91168aab420a7 100644 --- a/modules/prediction/common/prediction_gflags.cc +++ b/modules/prediction/common/prediction_gflags.cc @@ -118,6 +118,9 @@ DEFINE_string(evaluator_vehicle_mlp_file, DEFINE_string(evaluator_vehicle_rnn_file, "/apollo/modules/prediction/data/rnn_vehicle_model.bin", "rnn model file for vehicle evaluator"); +DEFINE_string(evaluator_vehicle_cruise_mlp_file, + "/apollo/modules/prediction/data/cruise_mlp_vehicle_model.bin", + "Vehicle cruise MLP model file"); DEFINE_string(evaluator_vehicle_junction_mlp_file, "/apollo/modules/prediction/data/junction_mlp_vehicle_model.bin", "Vehicle junction MLP model file"); diff --git a/modules/prediction/common/prediction_gflags.h b/modules/prediction/common/prediction_gflags.h index 1fbf682fddab5c1231961bb4b13a914572d5c50f..5a1afc0c9257df4fce7d29d500db42edafd0864d 100644 --- a/modules/prediction/common/prediction_gflags.h +++ b/modules/prediction/common/prediction_gflags.h @@ -81,6 +81,7 @@ DECLARE_double(prediction_pedestrian_total_time); DECLARE_double(still_speed); DECLARE_string(evaluator_vehicle_mlp_file); DECLARE_string(evaluator_vehicle_rnn_file); +DECLARE_string(evaluator_vehicle_cruise_mlp_file); DECLARE_string(evaluator_vehicle_junction_mlp_file); DECLARE_int32(max_num_obstacles); DECLARE_double(valid_position_diff_threshold); diff --git a/modules/prediction/common/prediction_util_test.cc b/modules/prediction/common/prediction_util_test.cc index 73de49e8ff4a48d451bc867b6d6ff2d4333a5528..94293e62eec7dc571da18dc3abe324cccb1a82d7 100644 --- a/modules/prediction/common/prediction_util_test.cc +++ b/modules/prediction/common/prediction_util_test.cc @@ -62,11 +62,15 @@ TEST(PredictionUtilTest, solve_cubic_polynomial_and_evaluate) { double param = 5.0; auto coefs = ComputePolynomial<3>(start, end, param); - EXPECT_DOUBLE_EQ(EvaluateCubicPolynomial(coefs, 0.0, 0, param, 1.0), start[0]); - EXPECT_DOUBLE_EQ(EvaluateCubicPolynomial(coefs, 0.0, 1, param, 1.0), start[1]); - - EXPECT_DOUBLE_EQ(EvaluateCubicPolynomial(coefs, param, 0, param, 1.0), end[0]); - EXPECT_DOUBLE_EQ(EvaluateCubicPolynomial(coefs, param, 1, param, 1.0), end[1]); + EXPECT_DOUBLE_EQ(EvaluateCubicPolynomial(coefs, 0.0, 0, param, 1.0), + start[0]); + EXPECT_DOUBLE_EQ(EvaluateCubicPolynomial(coefs, 0.0, 1, param, 1.0), + start[1]); + + EXPECT_DOUBLE_EQ(EvaluateCubicPolynomial(coefs, param, 0, param, 1.0), + end[0]); + EXPECT_DOUBLE_EQ(EvaluateCubicPolynomial(coefs, param, 1, param, 1.0), + end[1]); } } // namespace math_util diff --git a/modules/prediction/evaluator/vehicle/cruise_mlp_evaluator.cc b/modules/prediction/evaluator/vehicle/cruise_mlp_evaluator.cc index eaadaedb4f140663ab44ee00a9083326af61abfa..b36b6ec74585ba3dbd8696882bee663d18fc61cc 100644 --- a/modules/prediction/evaluator/vehicle/cruise_mlp_evaluator.cc +++ b/modules/prediction/evaluator/vehicle/cruise_mlp_evaluator.cc @@ -21,7 +21,7 @@ namespace apollo { namespace prediction { CruiseMLPEvaluator::CruiseMLPEvaluator() { - LoadModel(FLAGS_evaluator_vehicle_junction_mlp_file); + LoadModel(FLAGS_evaluator_vehicle_cruise_mlp_file); } void CruiseMLPEvaluator::Evaluate(Obstacle* obstacle_ptr) { diff --git a/modules/prediction/evaluator/vehicle/cruise_mlp_evaluator.h b/modules/prediction/evaluator/vehicle/cruise_mlp_evaluator.h index 31d64f8cea4575f10c66fdce371a00809764acf6..43f79621bbfa84bac30034bda6deb15c353aecc3 100644 --- a/modules/prediction/evaluator/vehicle/cruise_mlp_evaluator.h +++ b/modules/prediction/evaluator/vehicle/cruise_mlp_evaluator.h @@ -59,7 +59,6 @@ class CruiseMLPEvaluator : public Evaluator { double ComputeFinishTime(const std::vector& feature_values); private: - std::unordered_map> junction_exit_lane_ids_; }; } // namespace prediction