提交 d197c2a0 编写于 作者: K kechxu 提交者: Kecheng Xu

Prediction: cruise mlp model switches to pytorch

上级 36434c55
......@@ -138,6 +138,9 @@ DEFINE_string(evaluator_cruise_vehicle_cutin_model_file,
DEFINE_string(torch_vehicle_junction_mlp_file,
"/apollo/modules/prediction/data/junction_mlp_vehicle_model.pt",
"Vehicle junction MLP model file");
DEFINE_string(torch_vehicle_cruise_cutin_file,
"/apollo/modules/prediction/data/cruise_cutin_vehicle_model.pt",
"Vehicle cruise cutin model file");
DEFINE_string(evaluator_vehicle_junction_mlp_file,
"/apollo/modules/prediction/data/junction_mlp_vehicle_model.bin",
"Vehicle junction MLP model file");
......
......@@ -85,6 +85,7 @@ DECLARE_double(pedestrian_max_acc);
DECLARE_double(still_speed);
DECLARE_string(evaluator_vehicle_mlp_file);
DECLARE_string(torch_vehicle_junction_mlp_file);
DECLARE_string(torch_vehicle_cruise_cutin_file);
DECLARE_string(evaluator_cruise_vehicle_go_model_file);
DECLARE_string(evaluator_cruise_vehicle_cutin_model_file);
DECLARE_string(evaluator_vehicle_rnn_file);
......
......@@ -154,6 +154,7 @@ cc_library(
"//modules/prediction/container:container_manager",
"//modules/prediction/evaluator",
"//modules/prediction/network/cruise_model",
"@pytorch",
],
)
......
......@@ -17,6 +17,9 @@
#include <limits>
#include <utility>
#include "torch/script.h"
#include "torch/torch.h"
#include "cyber/common/file.h"
#include "modules/prediction/common/feature_output.h"
#include "modules/prediction/common/prediction_gflags.h"
......@@ -99,20 +102,27 @@ void CruiseMLPEvaluator::Evaluate(Obstacle* obstacle_ptr) {
return; // Skip Compute probability for offline mode
}
Eigen::MatrixXf obs_feature_mat =
VectorToMatrixXf(feature_values, 0, OBSTACLE_FEATURE_SIZE);
Eigen::MatrixXf lane_feature_mat = VectorToMatrixXf(
feature_values, OBSTACLE_FEATURE_SIZE + INTERACTION_FEATURE_SIZE,
static_cast<int>(feature_values.size()), SINGLE_LANE_FEATURE_SIZE,
LANE_POINTS_SIZE);
Eigen::MatrixXf model_output;
if (lane_sequence_ptr->vehicle_on_lane()) {
go_model_ptr_->Run({lane_feature_mat, obs_feature_mat}, &model_output);
} else {
cutin_model_ptr_->Run({lane_feature_mat, obs_feature_mat}, &model_output);
torch::Device device(torch::kCPU);
// TODO(all) uncomment the following when cuda issue is resolved
// if (torch::cuda::is_available()) {
// ADEBUG << "CUDA is available";
// device = torch::Device(torch::kCUDA);
// }
std::vector<torch::jit::IValue> torch_inputs;
int input_dim = static_cast<int>(OBSTACLE_FEATURE_SIZE +
INTERACTION_FEATURE_SIZE + SINGLE_LANE_FEATURE_SIZE * LANE_POINTS_SIZE);
torch::Tensor torch_input = torch::zeros({1, input_dim});
for (size_t i = 0; i < feature_values.size(); ++i) {
torch_input[0][i] = static_cast<float>(feature_values[i]);
}
double probability = model_output(0, 0);
double finish_time = model_output(0, 1);
torch_inputs.push_back(torch_input.to(device));
std::shared_ptr<torch::jit::script::Module> torch_module =
torch::jit::load(FLAGS_torch_vehicle_cruise_cutin_file, device);
at::Tensor torch_output_tensor =
torch_module->forward(torch_inputs).toTensor();
auto torch_output = torch_output_tensor.accessor<float, 2>();
double probability = static_cast<double>(torch_output[0][0]);
double finish_time = static_cast<double>(torch_output[0][1]);
lane_sequence_ptr->set_probability(probability);
lane_sequence_ptr->set_time_to_lane_center(finish_time);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册