Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Pinoxchio
apollo
提交
d197c2a0
A
apollo
项目概览
Pinoxchio
/
apollo
与 Fork 源项目一致
从无法访问的项目Fork
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
apollo
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
d197c2a0
编写于
3月 06, 2019
作者:
K
kechxu
提交者:
Kecheng Xu
3月 10, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Prediction: cruise mlp model switches to pytorch
上级
36434c55
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
28 addition
and
13 deletion
+28
-13
modules/prediction/common/prediction_gflags.cc
modules/prediction/common/prediction_gflags.cc
+3
-0
modules/prediction/common/prediction_gflags.h
modules/prediction/common/prediction_gflags.h
+1
-0
modules/prediction/data/cruise_cutin_vehicle_model.pt
modules/prediction/data/cruise_cutin_vehicle_model.pt
+0
-0
modules/prediction/data/cruise_go_vehicle_model.pt
modules/prediction/data/cruise_go_vehicle_model.pt
+0
-0
modules/prediction/evaluator/vehicle/BUILD
modules/prediction/evaluator/vehicle/BUILD
+1
-0
modules/prediction/evaluator/vehicle/cruise_mlp_evaluator.cc
modules/prediction/evaluator/vehicle/cruise_mlp_evaluator.cc
+23
-13
未找到文件。
modules/prediction/common/prediction_gflags.cc
浏览文件 @
d197c2a0
...
...
@@ -138,6 +138,9 @@ DEFINE_string(evaluator_cruise_vehicle_cutin_model_file,
DEFINE_string
(
torch_vehicle_junction_mlp_file
,
"/apollo/modules/prediction/data/junction_mlp_vehicle_model.pt"
,
"Vehicle junction MLP model file"
);
DEFINE_string
(
torch_vehicle_cruise_cutin_file
,
"/apollo/modules/prediction/data/cruise_cutin_vehicle_model.pt"
,
"Vehicle cruise cutin model file"
);
DEFINE_string
(
evaluator_vehicle_junction_mlp_file
,
"/apollo/modules/prediction/data/junction_mlp_vehicle_model.bin"
,
"Vehicle junction MLP model file"
);
...
...
modules/prediction/common/prediction_gflags.h
浏览文件 @
d197c2a0
...
...
@@ -85,6 +85,7 @@ DECLARE_double(pedestrian_max_acc);
DECLARE_double
(
still_speed
);
DECLARE_string
(
evaluator_vehicle_mlp_file
);
DECLARE_string
(
torch_vehicle_junction_mlp_file
);
DECLARE_string
(
torch_vehicle_cruise_cutin_file
);
DECLARE_string
(
evaluator_cruise_vehicle_go_model_file
);
DECLARE_string
(
evaluator_cruise_vehicle_cutin_model_file
);
DECLARE_string
(
evaluator_vehicle_rnn_file
);
...
...
modules/prediction/data/cruise_cutin_vehicle_model.pt
0 → 100644
浏览文件 @
d197c2a0
文件已添加
modules/prediction/data/cruise_go_vehicle_model.pt
0 → 100644
浏览文件 @
d197c2a0
文件已添加
modules/prediction/evaluator/vehicle/BUILD
浏览文件 @
d197c2a0
...
...
@@ -154,6 +154,7 @@ cc_library(
"//modules/prediction/container:container_manager"
,
"//modules/prediction/evaluator"
,
"//modules/prediction/network/cruise_model"
,
"@pytorch"
,
],
)
...
...
modules/prediction/evaluator/vehicle/cruise_mlp_evaluator.cc
浏览文件 @
d197c2a0
...
...
@@ -17,6 +17,9 @@
#include <limits>
#include <utility>
#include "torch/script.h"
#include "torch/torch.h"
#include "cyber/common/file.h"
#include "modules/prediction/common/feature_output.h"
#include "modules/prediction/common/prediction_gflags.h"
...
...
@@ -99,20 +102,27 @@ void CruiseMLPEvaluator::Evaluate(Obstacle* obstacle_ptr) {
return
;
// Skip Compute probability for offline mode
}
Eigen
::
MatrixXf
obs_feature_mat
=
VectorToMatrixXf
(
feature_values
,
0
,
OBSTACLE_FEATURE_SIZE
);
Eigen
::
MatrixXf
lane_feature_mat
=
VectorToMatrixXf
(
feature_values
,
OBSTACLE_FEATURE_SIZE
+
INTERACTION_FEATURE_SIZE
,
static_cast
<
int
>
(
feature_values
.
size
()),
SINGLE_LANE_FEATURE_SIZE
,
LANE_POINTS_SIZE
);
Eigen
::
MatrixXf
model_output
;
if
(
lane_sequence_ptr
->
vehicle_on_lane
())
{
go_model_ptr_
->
Run
({
lane_feature_mat
,
obs_feature_mat
},
&
model_output
);
}
else
{
cutin_model_ptr_
->
Run
({
lane_feature_mat
,
obs_feature_mat
},
&
model_output
);
torch
::
Device
device
(
torch
::
kCPU
);
// TODO(all) uncomment the following when cuda issue is resolved
// if (torch::cuda::is_available()) {
// ADEBUG << "CUDA is available";
// device = torch::Device(torch::kCUDA);
// }
std
::
vector
<
torch
::
jit
::
IValue
>
torch_inputs
;
int
input_dim
=
static_cast
<
int
>
(
OBSTACLE_FEATURE_SIZE
+
INTERACTION_FEATURE_SIZE
+
SINGLE_LANE_FEATURE_SIZE
*
LANE_POINTS_SIZE
);
torch
::
Tensor
torch_input
=
torch
::
zeros
({
1
,
input_dim
});
for
(
size_t
i
=
0
;
i
<
feature_values
.
size
();
++
i
)
{
torch_input
[
0
][
i
]
=
static_cast
<
float
>
(
feature_values
[
i
]);
}
double
probability
=
model_output
(
0
,
0
);
double
finish_time
=
model_output
(
0
,
1
);
torch_inputs
.
push_back
(
torch_input
.
to
(
device
));
std
::
shared_ptr
<
torch
::
jit
::
script
::
Module
>
torch_module
=
torch
::
jit
::
load
(
FLAGS_torch_vehicle_cruise_cutin_file
,
device
);
at
::
Tensor
torch_output_tensor
=
torch_module
->
forward
(
torch_inputs
).
toTensor
();
auto
torch_output
=
torch_output_tensor
.
accessor
<
float
,
2
>
();
double
probability
=
static_cast
<
double
>
(
torch_output
[
0
][
0
]);
double
finish_time
=
static_cast
<
double
>
(
torch_output
[
0
][
1
]);
lane_sequence_ptr
->
set_probability
(
probability
);
lane_sequence_ptr
->
set_time_to_lane_center
(
finish_time
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录