Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Pinoxchio
apollo
提交
26ba385d
A
apollo
项目概览
Pinoxchio
/
apollo
与 Fork 源项目一致
从无法访问的项目Fork
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
apollo
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
26ba385d
编写于
3月 28, 2019
作者:
K
kechxu
提交者:
Kecheng Xu
3月 28, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Prediction: online integrate lane scanning model
上级
a3507e90
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
85 addition
and
4 deletion
+85
-4
modules/prediction/common/prediction_gflags.cc
modules/prediction/common/prediction_gflags.cc
+3
-0
modules/prediction/common/prediction_gflags.h
modules/prediction/common/prediction_gflags.h
+1
-0
modules/prediction/data/lane_scanning_vehicle_model.pt
modules/prediction/data/lane_scanning_vehicle_model.pt
+0
-0
modules/prediction/evaluator/vehicle/BUILD
modules/prediction/evaluator/vehicle/BUILD
+8
-1
modules/prediction/evaluator/vehicle/lane_scanning_evaluator.cc
...s/prediction/evaluator/vehicle/lane_scanning_evaluator.cc
+49
-2
modules/prediction/evaluator/vehicle/lane_scanning_evaluator.h
...es/prediction/evaluator/vehicle/lane_scanning_evaluator.h
+19
-0
modules/prediction/proto/feature.proto
modules/prediction/proto/feature.proto
+5
-1
未找到文件。
modules/prediction/common/prediction_gflags.cc
浏览文件 @
26ba385d
...
...
@@ -140,6 +140,9 @@ DEFINE_string(torch_vehicle_cruise_go_file,
DEFINE_string
(
torch_vehicle_cruise_cutin_file
,
"/apollo/modules/prediction/data/cruise_cutin_vehicle_model.pt"
,
"Vehicle cruise go model file"
);
DEFINE_string
(
torch_vehicle_lane_scanning_file
,
"/apollo/modules/prediction/data/lane_scanning_vehicle_model.pt"
,
"Vehicle lane scanning model file"
);
DEFINE_int32
(
max_num_obstacles
,
300
,
"maximal number of obstacles stored in obstacles container."
);
DEFINE_double
(
valid_position_diff_threshold
,
0.5
,
...
...
modules/prediction/common/prediction_gflags.h
浏览文件 @
26ba385d
...
...
@@ -88,6 +88,7 @@ DECLARE_string(evaluator_vehicle_mlp_file);
DECLARE_string
(
torch_vehicle_junction_mlp_file
);
DECLARE_string
(
torch_vehicle_cruise_go_file
);
DECLARE_string
(
torch_vehicle_cruise_cutin_file
);
DECLARE_string
(
torch_vehicle_lane_scanning_file
);
DECLARE_string
(
evaluator_vehicle_rnn_file
);
DECLARE_string
(
evaluator_vehicle_cruise_mlp_file
);
DECLARE_int32
(
max_num_obstacles
);
...
...
modules/prediction/data/lane_scanning_vehicle_model.pt
0 → 100644
浏览文件 @
26ba385d
文件已添加
modules/prediction/evaluator/vehicle/BUILD
浏览文件 @
26ba385d
...
...
@@ -191,7 +191,14 @@ cc_library(
deps
=
[
"//modules/prediction/container:container_manager"
,
"//modules/prediction/evaluator"
,
],
]
+
select
({
"//tools/platforms:use_gpu"
:
[
"@pytorch"
,
],
"//conditions:default"
:
[
"@pytorch"
,
],
}),
)
cpplint
()
modules/prediction/evaluator/vehicle/lane_scanning_evaluator.cc
浏览文件 @
26ba385d
...
...
@@ -19,6 +19,7 @@
#include <utility>
#include "cyber/common/file.h"
#include "modules/common/proto/pnc_point.pb.h"
#include "modules/prediction/common/feature_output.h"
#include "modules/prediction/common/prediction_gflags.h"
#include "modules/prediction/common/prediction_system_gflags.h"
...
...
@@ -30,9 +31,12 @@ namespace apollo {
namespace
prediction
{
using
apollo
::
common
::
adapter
::
AdapterConfig
;
using
apollo
::
common
::
TrajectoryPoint
;
using
apollo
::
cyber
::
common
::
GetProtoFromFile
;
LaneScanningEvaluator
::
LaneScanningEvaluator
()
{}
LaneScanningEvaluator
::
LaneScanningEvaluator
()
:
device_
(
torch
::
kCPU
)
{
LoadModel
();
}
void
LaneScanningEvaluator
::
Evaluate
(
Obstacle
*
obstacle_ptr
)
{
std
::
vector
<
Obstacle
*>
dummy_dynamic_env
;
...
...
@@ -77,7 +81,16 @@ void LaneScanningEvaluator::Evaluate(Obstacle* obstacle_ptr,
ADEBUG
<<
"Save extracted features for learning locally."
;
return
;
}
// TODO(jiacheng): once the model is trained, implement this online part.
std
::
vector
<
torch
::
jit
::
IValue
>
torch_inputs
;
torch
::
Tensor
torch_input
=
torch
::
zeros
({
1
,
static_cast
<
int
>
(
feature_values
.
size
())});
for
(
size_t
i
=
0
;
i
<
feature_values
.
size
();
++
i
)
{
torch_input
[
0
][
i
]
=
static_cast
<
float
>
(
feature_values
[
i
]);
}
torch_inputs
.
push_back
(
std
::
move
(
torch_input
));
ModelInference
(
torch_inputs
,
torch_lane_scanning_model_ptr_
,
latest_feature_ptr
);
}
bool
LaneScanningEvaluator
::
ExtractFeatures
(
...
...
@@ -294,8 +307,42 @@ bool LaneScanningEvaluator::ExtractStaticEnvFeatures(
}
}
size_t
max_feature_size
=
LANE_POINTS_SIZE
*
SINGLE_LANE_FEATURE_SIZE
*
MAX_NUM_LANE
;
while
(
feature_values
->
size
()
<
max_feature_size
)
{
feature_values
->
push_back
(
0.0
);
}
return
true
;
}
void
LaneScanningEvaluator
::
LoadModel
()
{
// TODO(all) uncomment the following when cuda issue is resolved
// if (torch::cuda::is_available()) {
// ADEBUG << "CUDA is available";
// device_ = torch::Device(torch::kCUDA);
// }
torch
::
set_num_threads
(
1
);
torch_lane_scanning_model_ptr_
=
torch
::
jit
::
load
(
FLAGS_torch_vehicle_lane_scanning_file
,
device_
);
}
void
LaneScanningEvaluator
::
ModelInference
(
const
std
::
vector
<
torch
::
jit
::
IValue
>&
torch_inputs
,
std
::
shared_ptr
<
torch
::
jit
::
script
::
Module
>
torch_model_ptr
,
Feature
*
feature_ptr
)
{
auto
torch_output_tensor
=
torch_model_ptr
->
forward
(
torch_inputs
).
toTensor
();
auto
torch_output
=
torch_output_tensor
.
accessor
<
float
,
2
>
();
for
(
size_t
i
=
0
;
i
<
SHORT_TERM_TRAJECTORY_SIZE
;
++
i
)
{
TrajectoryPoint
point
;
double
x
=
static_cast
<
double
>
(
torch_output
[
0
][
2
*
i
]);
double
y
=
static_cast
<
double
>
(
torch_output
[
0
][
2
*
i
+
1
]);
point
.
mutable_path_point
()
->
set_x
(
x
);
point
.
mutable_path_point
()
->
set_y
(
y
);
feature_ptr
->
add_short_term_predicted_trajectory_points
()
->
CopyFrom
(
point
);
}
}
}
// namespace prediction
}
// namespace apollo
modules/prediction/evaluator/vehicle/lane_scanning_evaluator.h
浏览文件 @
26ba385d
...
...
@@ -20,6 +20,9 @@
#include <string>
#include <vector>
#include "torch/script.h"
#include "torch/torch.h"
#include "modules/prediction/evaluator/evaluator.h"
namespace
apollo
{
...
...
@@ -67,6 +70,11 @@ class LaneScanningEvaluator : public Evaluator {
std
::
string
GetName
()
override
{
return
"LANE_SCANNING_EVALUATOR"
;
}
private:
/**
* @brief Load model from file
*/
void
LoadModel
();
/**
* @brief Extract the features for obstacles
* @param Obstacle pointer
...
...
@@ -84,11 +92,22 @@ class LaneScanningEvaluator : public Evaluator {
const
LaneGraph
*
lane_graph_ptr
,
std
::
vector
<
double
>*
feature_values
);
void
ModelInference
(
const
std
::
vector
<
torch
::
jit
::
IValue
>&
torch_inputs
,
std
::
shared_ptr
<
torch
::
jit
::
script
::
Module
>
torch_model_ptr
,
Feature
*
feature_ptr
);
private:
static
const
size_t
OBSTACLE_FEATURE_SIZE
=
5
*
9
;
static
const
size_t
INTERACTION_FEATURE_SIZE
=
8
;
static
const
size_t
SINGLE_LANE_FEATURE_SIZE
=
4
;
static
const
size_t
LANE_POINTS_SIZE
=
100
;
// (100 * 0.2m = 20m)
static
const
size_t
MAX_NUM_LANE
=
10
;
static
const
size_t
SHORT_TERM_TRAJECTORY_SIZE
=
10
;
std
::
shared_ptr
<
torch
::
jit
::
script
::
Module
>
torch_lane_scanning_model_ptr_
=
nullptr
;
torch
::
Device
device_
;
};
}
// namespace prediction
...
...
modules/prediction/proto/feature.proto
浏览文件 @
26ba385d
...
...
@@ -3,6 +3,7 @@ syntax = "proto2";
package
apollo
.
prediction
;
import
"modules/common/proto/geometry.proto"
;
import
"modules/common/proto/pnc_point.proto"
;
import
"modules/perception/proto/perception_obstacle.proto"
;
import
"modules/prediction/proto/lane_graph.proto"
;
import
"modules/prediction/proto/prediction_point.proto"
;
...
...
@@ -63,7 +64,7 @@ message ObstaclePriority {
optional
Priority
priority
=
25
[
default
=
NORMAL
];
}
// next id = 3
2
// next id = 3
3
message
Feature
{
// Obstacle ID
optional
int32
id
=
1
;
...
...
@@ -107,6 +108,9 @@ message Feature {
// Obstacle ground-truth labels:
repeated
PredictionTrajectoryPoint
future_trajectory_points
=
31
;
// Obstacle short-term predicted trajectory points
repeated
common.TrajectoryPoint
short_term_predicted_trajectory_points
=
32
;
}
message
ObstacleHistory
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录