提交 3e3ead38 编写于 作者: K kechxu 提交者: Jiangtao Hu

Prediction: implement cruise_mlp_evaluator

上级 09a602ba
......@@ -121,9 +121,12 @@ DEFINE_string(evaluator_vehicle_mlp_file,
DEFINE_string(evaluator_vehicle_rnn_file,
"/apollo/modules/prediction/data/rnn_vehicle_model.bin",
"rnn model file for vehicle evaluator");
DEFINE_string(evaluator_vehicle_cruise_mlp_file,
"/apollo/modules/prediction/data/cruise_mlp_vehicle_model.bin",
"Vehicle cruise MLP model file");
DEFINE_string(evaluator_cruise_vehicle_go_model_file,
"/apollo/modules/prediction/data/cruise_vehicle_go_model.bin",
"Vehicle cruise go model file");
DEFINE_string(evaluator_cruise_vehicle_cutin_model_file,
"/apollo/modules/prediction/data/cruise_vehicle_cutin_model.bin",
"Vehicle cruise cutin model file");
DEFINE_string(evaluator_vehicle_junction_mlp_file,
"/apollo/modules/prediction/data/junction_mlp_vehicle_model.bin",
"Vehicle junction MLP model file");
......
......@@ -81,6 +81,8 @@ DECLARE_double(pedestrian_max_acc);
DECLARE_double(prediction_pedestrian_total_time);
DECLARE_double(still_speed);
DECLARE_string(evaluator_vehicle_mlp_file);
DECLARE_string(evaluator_cruise_vehicle_go_model_file);
DECLARE_string(evaluator_cruise_vehicle_cutin_model_file);
DECLARE_string(evaluator_vehicle_rnn_file);
DECLARE_string(evaluator_vehicle_cruise_mlp_file);
DECLARE_string(evaluator_vehicle_junction_mlp_file);
......
......@@ -172,6 +172,7 @@ cc_library(
"//modules/prediction/common:prediction_gflags",
"//modules/prediction/common:prediction_util",
"//modules/prediction/common:validation_checker",
"//modules/prediction/network/cruise_model:cruise_model",
"//modules/prediction/container:container_manager",
"//modules/prediction/container/obstacles:obstacle",
"//modules/prediction/evaluator",
......
......@@ -17,6 +17,8 @@
#include <limits>
#include <utility>
#include "Eigen/Dense"
#include "modules/prediction/evaluator/vehicle/cruise_mlp_evaluator.h"
#include "modules/common/math/math_utils.h"
#include "modules/common/util/file.h"
......@@ -33,6 +35,7 @@ namespace apollo {
namespace prediction {
using apollo::common::adapter::AdapterConfig;
using apollo::common::util::GetProtoFromFile;
// Helper function for computing the mean value of a vector.
double ComputeMean(const std::vector<double>& nums, size_t start, size_t end) {
......@@ -45,6 +48,19 @@ double ComputeMean(const std::vector<double>& nums, size_t start, size_t end) {
return (count == 0) ? 0.0 : sum / count;
}
Eigen::MatrixXf VectorToMatrixXf(const std::vector<double>& nums,
const int start_index, const int end_index) {
CHECK_LT(start_index, end_index);
CHECK_GE(start_index, 0);
CHECK_LE(end_index, static_cast<int>(nums.size()));
Eigen::MatrixXf output_matrix;
output_matrix.resize(1, end_index - start_index);
for (int i = start_index; i < end_index; ++i) {
output_matrix(1, i - start_index) = static_cast<float>(nums[i]);
}
return output_matrix;
}
// Helper function for converting world coordinate to relative coordinate
// with respect to the object (obstacle or ADC)
std::pair<double, double> WorldCoordToObjCoord
......@@ -65,7 +81,9 @@ double WorldAngleToObjAngle(double input_world_angle,
}
CruiseMLPEvaluator::CruiseMLPEvaluator() {
LoadModel(FLAGS_evaluator_vehicle_cruise_mlp_file);
// TODO(kechxu) name go and cutin models
LoadModels(FLAGS_evaluator_cruise_vehicle_go_model_file,
FLAGS_evaluator_cruise_vehicle_go_model_file);
}
void CruiseMLPEvaluator::Clear() {
......@@ -102,7 +120,20 @@ void CruiseMLPEvaluator::Evaluate(Obstacle* obstacle_ptr) {
CHECK_NOTNULL(lane_sequence_ptr);
std::vector<double> feature_values;
ExtractFeatureValues(obstacle_ptr, lane_sequence_ptr, &feature_values);
double finish_time = ComputeFinishTime(feature_values);
Eigen::MatrixXf obs_feature_mat = VectorToMatrixXf(feature_values, 0,
OBSTACLE_FEATURE_SIZE + INTERACTION_FEATURE_SIZE);
Eigen::MatrixXf lane_feature_mat = VectorToMatrixXf(feature_values,
OBSTACLE_FEATURE_SIZE + INTERACTION_FEATURE_SIZE,
static_cast<int>(feature_values.size()));
Eigen::MatrixXf model_output;
if (lane_sequence_ptr->vehicle_on_lane()) {
go_model_ptr_->Run({lane_feature_mat, obs_feature_mat}, &model_output);
} else {
cutin_model_ptr_->Run({lane_feature_mat, obs_feature_mat}, &model_output);
}
double probability = model_output(0, 0);
double finish_time = model_output(0, 1);
lane_sequence_ptr->set_probability(probability);
lane_sequence_ptr->set_time_to_lane_center(finish_time);
}
......@@ -525,15 +556,25 @@ void CruiseMLPEvaluator::SetLaneFeatureValues
}
// TODO(all): uncomment this once the model is trained and ready.
void CruiseMLPEvaluator::LoadModel(const std::string& model_file) {
// Currently, it's using FnnVehicleModel
// TODO(all) implement it using the generic "network" class.
// model_ptr_.reset(new FnnVehicleModel());
// CHECK(model_ptr_ != nullptr);
// CHECK(common::util::GetProtoFromFile(model_file, model_ptr_.get()))
// << "Unable to load model file: " << model_file << ".";
// AINFO << "Succeeded in loading the model file: " << model_file << ".";
void CruiseMLPEvaluator::LoadModels(const std::string& go_model_file,
const std::string& cutin_model_file) {
go_model_ptr_.reset(new network::CruiseModel());
cutin_model_ptr_.reset(new network::CruiseModel());
CHECK_NOTNULL(go_model_ptr_);
CHECK_NOTNULL(cutin_model_ptr_);
CruiseModelParameter go_model_param;
CruiseModelParameter cutin_model_param;
CHECK(GetProtoFromFile(go_model_file, &go_model_param))
<< "Unable to load go model file: " << go_model_file << ".";
CHECK(GetProtoFromFile(cutin_model_file, &cutin_model_param))
<< "Unable to load cutin model file: " << cutin_model_file << ".";
go_model_ptr_->LoadModel(go_model_param);
cutin_model_ptr_->LoadModel(cutin_model_param);
ADEBUG << "Succeeded in loading go model: " << go_model_file << ".";
ADEBUG << "Succeeded in loading cutin model: " << cutin_model_file << ".";
}
// TODO(all): implement this once the model is trained and ready.
......
......@@ -26,6 +26,7 @@
#include "modules/prediction/proto/feature.pb.h"
#include "modules/prediction/proto/fnn_vehicle_model.pb.h"
#include "modules/prediction/proto/lane_graph.pb.h"
#include "modules/prediction/network/cruise_model/cruise_model.h"
namespace apollo {
namespace prediction {
......@@ -88,10 +89,12 @@ class CruiseMLPEvaluator : public Evaluator {
std::vector<double>* feature_values);
/**
* @brief Load mode file
* @param Model file name
* @brief Load mode files
* @param Go model file name
* @param Cutin model file name
*/
void LoadModel(const std::string& model_file);
void LoadModels(const std::string& go_model_file,
const std::string& cutin_model_file);
/**
* @brief Compute probability of a junction exit
......@@ -110,7 +113,9 @@ class CruiseMLPEvaluator : public Evaluator {
static const size_t OBSTACLE_FEATURE_SIZE = 23 + 60;
static const size_t INTERACTION_FEATURE_SIZE = 8;
static const size_t LANE_FEATURE_SIZE = 150;
std::unique_ptr<FnnVehicleModel> model_ptr_;
std::shared_ptr<network::CruiseModel> go_model_ptr_;
std::shared_ptr<network::CruiseModel> cutin_model_ptr_;
};
} // namespace prediction
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册