Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Pinoxchio
apollo
提交
09a602ba
A
apollo
项目概览
Pinoxchio
/
apollo
与 Fork 源项目一致
从无法访问的项目Fork
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
apollo
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
09a602ba
编写于
11月 08, 2018
作者:
P
panjiacheng
提交者:
Jiangtao Hu
12月 13, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Prediction: added all available features.
上级
4ff9448e
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
10 addition
and
10 deletion
+10
-10
modules/tools/prediction/mlp_train/cruiseMLP_train.py
modules/tools/prediction/mlp_train/cruiseMLP_train.py
+2
-2
modules/tools/prediction/mlp_train/cruise_models.py
modules/tools/prediction/mlp_train/cruise_models.py
+8
-8
未找到文件。
modules/tools/prediction/mlp_train/cruiseMLP_train.py
浏览文件 @
09a602ba
...
...
@@ -255,7 +255,7 @@ def data_preprocessing(data):
# mask out those that don't have any history
mask5
=
(
data
[:,
53
]
!=
100
)
X
=
np
.
concatenate
((
X_obs_hist_5
,
X_lane
),
axis
=
1
)
X
=
np
.
concatenate
((
X_obs_
old_features
,
X_surround_obs
,
X_obs_
hist_5
,
X_lane
),
axis
=
1
)
X
=
X
[
mask5
,
:]
y
=
data
[:,
-
dim_output
:]
y
=
y
[
mask5
,
:]
...
...
@@ -398,7 +398,7 @@ def validate_vanilla(valid_X, valid_y, model, batch_size=2048, balance=1.0, pos_
valid_y
=
valid_y
.
data
.
cpu
().
numpy
()
valid_auc
=
sklearn
.
metrics
.
roc_auc_score
(
valid_y
[:,
0
],
pred_y
.
reshape
(
-
1
))
pred_y
=
(
pred_y
>
0.
5
)
pred_y
=
(
pred_y
>
0.
0
)
valid_accuracy
=
sklearn
.
metrics
.
accuracy_score
(
valid_y
[:,
0
],
pred_y
.
reshape
(
-
1
))
valid_precision
=
sklearn
.
metrics
.
precision_score
(
valid_y
[:,
0
],
pred_y
.
reshape
(
-
1
),
pos_label
=
pos_label
)
valid_recall
=
sklearn
.
metrics
.
recall_score
(
valid_y
[:,
0
],
pred_y
.
reshape
(
-
1
),
pos_label
=
pos_label
)
...
...
modules/tools/prediction/mlp_train/cruise_models.py
浏览文件 @
09a602ba
...
...
@@ -115,10 +115,10 @@ class FCNN_CNN1D(torch.nn.Module):
self
.
lane_feature_dropout
=
nn
.
Dropout
(
0.0
)
self
.
obs_feature_fc
=
torch
.
nn
.
Sequential
(
\
nn
.
Linear
(
24
,
17
),
\
nn
.
Linear
(
55
,
32
),
\
nn
.
Sigmoid
(),
\
nn
.
Dropout
(
0.0
),
\
nn
.
Linear
(
17
,
12
),
\
nn
.
Linear
(
32
,
24
),
\
nn
.
Sigmoid
(),
\
nn
.
Dropout
(
0.0
),
\
)
...
...
@@ -156,13 +156,13 @@ class FCNN_CNN1D(torch.nn.Module):
nn
.
ReLU
()
)
def
forward
(
self
,
x
):
lane_fea
=
x
[:,
24
:]
lane_fea
=
x
[:,
55
:]
lane_fea
=
lane_fea
.
view
(
lane_fea
.
size
(
0
),
5
,
30
)
obs_fea
=
x
[:,:
24
]
obs_fea
=
x
[:,:
55
]
lane_fea
=
self
.
lane_feature_conv
(
lane_fea
)
#print (lane_fea.shape)
lane_fea_max
=
self
.
lane_feature_maxpool
(
lane_fea
)
lane_fea_avg
=
self
.
lane_feature_avgpool
(
lane_fea
)
...
...
@@ -170,10 +170,10 @@ class FCNN_CNN1D(torch.nn.Module):
lane_fea_avg
.
view
(
lane_fea_avg
.
size
(
0
),
-
1
)],
1
)
lane_fea
=
self
.
lane_feature_dropout
(
lane_fea
)
#
obs_fea = self.obs_feature_fc(obs_fea)
#print (lane_fea.shape)
obs_fea
=
self
.
obs_feature_fc
(
obs_fea
)
tot_fea
=
torch
.
cat
([
lane_fea
,
obs_fea
],
1
)
out_c
=
self
.
classify
(
tot_fea
)
out_r
=
self
.
regress
(
torch
.
cat
([
tot_fea
,
out_c
],
1
))
return
out_c
,
out_r
\ No newline at end of file
return
out_c
,
out_r
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录