提交 ef8fc7f9 编写于 作者: K kechxu 提交者: Kecheng Xu

Prediction: add cruise torch go model

上级 3de38725
......@@ -138,9 +138,12 @@ DEFINE_string(evaluator_cruise_vehicle_cutin_model_file,
DEFINE_string(torch_vehicle_junction_mlp_file,
"/apollo/modules/prediction/data/junction_mlp_vehicle_model.pt",
"Vehicle junction MLP model file");
DEFINE_string(torch_vehicle_cruise_go_file,
"/apollo/modules/prediction/data/cruise_go_vehicle_model.pt",
"Vehicle cruise cutin model file");
DEFINE_string(torch_vehicle_cruise_cutin_file,
"/apollo/modules/prediction/data/cruise_cutin_vehicle_model.pt",
"Vehicle cruise cutin model file");
"Vehicle cruise go model file");
DEFINE_string(evaluator_vehicle_junction_mlp_file,
"/apollo/modules/prediction/data/junction_mlp_vehicle_model.bin",
"Vehicle junction MLP model file");
......
......@@ -85,6 +85,7 @@ DECLARE_double(pedestrian_max_acc);
DECLARE_double(still_speed);
DECLARE_string(evaluator_vehicle_mlp_file);
DECLARE_string(torch_vehicle_junction_mlp_file);
DECLARE_string(torch_vehicle_cruise_go_file);
DECLARE_string(torch_vehicle_cruise_cutin_file);
DECLARE_string(evaluator_cruise_vehicle_go_model_file);
DECLARE_string(evaluator_cruise_vehicle_cutin_model_file);
......
......@@ -108,14 +108,25 @@ void CruiseMLPEvaluator::Evaluate(Obstacle* obstacle_ptr) {
torch_input[0][i] = static_cast<float>(feature_values[i]);
}
torch_inputs.push_back(torch_input.to(device_));
auto torch_output_tuple =
torch_cutin_model_ptr_->forward(torch_inputs).toTuple();
auto probability_tensor = torch_output_tuple->elements()[0].toTensor();
auto finish_time_tensor = torch_output_tuple->elements()[1].toTensor();
lane_sequence_ptr->set_probability(Sigmoid(static_cast<double>(
probability_tensor.accessor<float, 2>()[0][0])));
lane_sequence_ptr->set_time_to_lane_center(
static_cast<double>(finish_time_tensor.accessor<float, 2>()[0][0]));
if (lane_sequence_ptr->vehicle_on_lane()) {
auto torch_output_tuple =
torch_go_model_ptr_->forward(torch_inputs).toTuple();
auto probability_tensor = torch_output_tuple->elements()[0].toTensor();
auto finish_time_tensor = torch_output_tuple->elements()[1].toTensor();
lane_sequence_ptr->set_probability(Sigmoid(static_cast<double>(
probability_tensor.accessor<float, 2>()[0][0])));
lane_sequence_ptr->set_time_to_lane_center(
static_cast<double>(finish_time_tensor.accessor<float, 2>()[0][0]));
} else {
auto torch_output_tuple =
torch_cutin_model_ptr_->forward(torch_inputs).toTuple();
auto probability_tensor = torch_output_tuple->elements()[0].toTensor();
auto finish_time_tensor = torch_output_tuple->elements()[1].toTensor();
lane_sequence_ptr->set_probability(Sigmoid(static_cast<double>(
probability_tensor.accessor<float, 2>()[0][0])));
lane_sequence_ptr->set_time_to_lane_center(
static_cast<double>(finish_time_tensor.accessor<float, 2>()[0][0]));
}
}
}
......@@ -140,20 +151,22 @@ void CruiseMLPEvaluator::ExtractFeatureValues(
feature_values->insert(feature_values->end(), obstacle_feature_values.begin(),
obstacle_feature_values.end());
// Extract interaction features.
// std::vector<double> interaction_feature_values;
// SetInteractionFeatureValues(obstacle_ptr, lane_sequence_ptr,
// &interaction_feature_values);
// if (interaction_feature_values.size() != INTERACTION_FEATURE_SIZE) {
// ADEBUG << "Obstacle [" << id << "] has fewer than "
// << "expected lane feature_values"
// << interaction_feature_values.size() << ".";
// return;
// }
// ADEBUG << "Interaction feature size = " << interaction_feature_values.size();
// feature_values->insert(feature_values->end(),
// interaction_feature_values.begin(),
// interaction_feature_values.end());
/*
Extract interaction features.
std::vector<double> interaction_feature_values;
SetInteractionFeatureValues(obstacle_ptr, lane_sequence_ptr,
&interaction_feature_values);
if (interaction_feature_values.size() != INTERACTION_FEATURE_SIZE) {
ADEBUG << "Obstacle [" << id << "] has fewer than "
<< "expected lane feature_values"
<< interaction_feature_values.size() << ".";
return;
}
ADEBUG << "Interaction feature size = " << interaction_feature_values.size();
feature_values->insert(feature_values->end(),
interaction_feature_values.begin(),
interaction_feature_values.end());
*/
// Extract lane related features.
std::vector<double> lane_feature_values;
......@@ -534,6 +547,8 @@ void CruiseMLPEvaluator::LoadModels() {
// ADEBUG << "CUDA is available";
// device_ = torch::Device(torch::kCUDA);
// }
torch_go_model_ptr_ = torch::jit::load(
FLAGS_torch_vehicle_cruise_go_file, device_);
torch_cutin_model_ptr_ = torch::jit::load(
FLAGS_torch_vehicle_cruise_cutin_file, device_);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册