提交 ec03f537 编写于 作者: K kechxu 提交者: Kecheng Xu

Prediction: retire unused code and data

上级 ef8fc7f9
......@@ -129,12 +129,6 @@ DEFINE_string(evaluator_vehicle_mlp_file,
DEFINE_string(evaluator_vehicle_rnn_file,
"/apollo/modules/prediction/data/rnn_vehicle_model.bin",
"rnn model file for vehicle evaluator");
DEFINE_string(evaluator_cruise_vehicle_go_model_file,
"/apollo/modules/prediction/data/cruise_go_vehicle_model.bin",
"Vehicle cruise go model file");
DEFINE_string(evaluator_cruise_vehicle_cutin_model_file,
"/apollo/modules/prediction/data/cruise_cutin_vehicle_model.bin",
"Vehicle cruise cut-in model file");
DEFINE_string(torch_vehicle_junction_mlp_file,
"/apollo/modules/prediction/data/junction_mlp_vehicle_model.pt",
"Vehicle junction MLP model file");
......@@ -144,9 +138,6 @@ DEFINE_string(torch_vehicle_cruise_go_file,
// TorchScript model used when the vehicle is cutting in to the lane
// (counterpart of torch_vehicle_cruise_go_file above).
// Fix: the help string previously said "go" — copy-paste error; this flag
// points at the cut-in model, as the file name shows.
DEFINE_string(torch_vehicle_cruise_cutin_file,
              "/apollo/modules/prediction/data/cruise_cutin_vehicle_model.pt",
              "Vehicle cruise cut-in model file");
DEFINE_string(evaluator_vehicle_junction_mlp_file,
"/apollo/modules/prediction/data/junction_mlp_vehicle_model.bin",
"Vehicle junction MLP model file");
DEFINE_int32(max_num_obstacles, 300,
"maximal number of obstacles stored in obstacles container.");
DEFINE_double(valid_position_diff_threshold, 0.5,
......
......@@ -87,11 +87,8 @@ DECLARE_string(evaluator_vehicle_mlp_file);
DECLARE_string(torch_vehicle_junction_mlp_file);
DECLARE_string(torch_vehicle_cruise_go_file);
DECLARE_string(torch_vehicle_cruise_cutin_file);
DECLARE_string(evaluator_cruise_vehicle_go_model_file);
DECLARE_string(evaluator_cruise_vehicle_cutin_model_file);
DECLARE_string(evaluator_vehicle_rnn_file);
DECLARE_string(evaluator_vehicle_cruise_mlp_file);
DECLARE_string(evaluator_vehicle_junction_mlp_file);
DECLARE_int32(max_num_obstacles);
DECLARE_double(valid_position_diff_threshold);
DECLARE_double(valid_position_diff_rate_threshold);
......
......@@ -109,23 +109,9 @@ void CruiseMLPEvaluator::Evaluate(Obstacle* obstacle_ptr) {
}
torch_inputs.push_back(torch_input.to(device_));
if (lane_sequence_ptr->vehicle_on_lane()) {
auto torch_output_tuple =
torch_go_model_ptr_->forward(torch_inputs).toTuple();
auto probability_tensor = torch_output_tuple->elements()[0].toTensor();
auto finish_time_tensor = torch_output_tuple->elements()[1].toTensor();
lane_sequence_ptr->set_probability(Sigmoid(static_cast<double>(
probability_tensor.accessor<float, 2>()[0][0])));
lane_sequence_ptr->set_time_to_lane_center(
static_cast<double>(finish_time_tensor.accessor<float, 2>()[0][0]));
ModelInference(torch_inputs, torch_go_model_ptr_, lane_sequence_ptr);
} else {
auto torch_output_tuple =
torch_cutin_model_ptr_->forward(torch_inputs).toTuple();
auto probability_tensor = torch_output_tuple->elements()[0].toTensor();
auto finish_time_tensor = torch_output_tuple->elements()[1].toTensor();
lane_sequence_ptr->set_probability(Sigmoid(static_cast<double>(
probability_tensor.accessor<float, 2>()[0][0])));
lane_sequence_ptr->set_time_to_lane_center(
static_cast<double>(finish_time_tensor.accessor<float, 2>()[0][0]));
ModelInference(torch_inputs, torch_cutin_model_ptr_, lane_sequence_ptr);
}
}
}
......@@ -553,17 +539,18 @@ void CruiseMLPEvaluator::LoadModels() {
FLAGS_torch_vehicle_cruise_cutin_file, device_);
}
// TODO(all): implement this once the model is trained and ready.
// Placeholder implementation: always reports a fixed 6.0 (seconds,
// presumably — confirm against callers); feature_values is currently
// ignored and will be used once a trained model replaces this stub
// (see the TODO above).
double CruiseMLPEvaluator::ComputeFinishTime(
const std::vector<double>& feature_values) {
return 6.0;
}
void CruiseMLPEvaluator::SaveOfflineFeatures(
LaneSequence* sequence, const std::vector<double>& feature_values) {
for (double feature_value : feature_values) {
sequence->mutable_features()->add_mlp_features(feature_value);
}
void CruiseMLPEvaluator::ModelInference(
    const std::vector<torch::jit::IValue>& torch_inputs,
    std::shared_ptr<torch::jit::script::Module> torch_model_ptr,
    LaneSequence* lane_sequence_ptr) {
  // Run the TorchScript model; it returns a 2-element tuple of
  // (probability logit, time-to-lane-center), each a [1 x 1] float tensor.
  const auto output_tuple = torch_model_ptr->forward(torch_inputs).toTuple();
  const auto& outputs = output_tuple->elements();
  auto probability = outputs[0].toTensor();
  auto finish_time = outputs[1].toTensor();
  // The probability head emits a raw logit; squash to [0, 1] via Sigmoid
  // before storing it on the lane sequence.
  const double probability_value =
      static_cast<double>(probability.accessor<float, 2>()[0][0]);
  lane_sequence_ptr->set_probability(Sigmoid(probability_value));
  const double finish_time_value =
      static_cast<double>(finish_time.accessor<float, 2>()[0][0]);
  lane_sequence_ptr->set_time_to_lane_center(finish_time_value);
}
} // namespace prediction
......
......@@ -97,18 +97,10 @@ class CruiseMLPEvaluator : public Evaluator {
*/
void LoadModels();
/**
* @brief Compute probability of a junction exit
*/
double ComputeFinishTime(const std::vector<double>& feature_values);
/**
* @brief Save offline feature values in proto
* @param Lane sequence
* @param Vector of feature values
*/
void SaveOfflineFeatures(LaneSequence* sequence,
const std::vector<double>& feature_values);
void ModelInference(
const std::vector<torch::jit::IValue>& torch_inputs,
std::shared_ptr<torch::jit::script::Module> torch_model_ptr,
LaneSequence* lane_sequence_ptr);
private:
static const size_t OBSTACLE_FEATURE_SIZE = 23 + 5 * 9;
......
......@@ -53,7 +53,7 @@ double ComputeMean(const std::vector<double>& nums, size_t start, size_t end) {
} // namespace
JunctionMLPEvaluator::JunctionMLPEvaluator() : device_(torch::kCPU) {
LoadModel(FLAGS_evaluator_vehicle_junction_mlp_file);
LoadModel(FLAGS_torch_vehicle_junction_mlp_file);
}
void JunctionMLPEvaluator::Clear() {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册