Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Pinoxchio
apollo
提交
c097b926
A
apollo
项目概览
Pinoxchio
/
apollo
与 Fork 源项目一致
从无法访问的项目Fork
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
apollo
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
c097b926
编写于
11月 04, 2018
作者:
P
panjiacheng
提交者:
Jiangtao Hu
12月 13, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Prediction: cruise training using 1d-CNN model seems to have significant improvements.
上级
f546619c
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
13 addition
and
12 deletion
+13
-12
modules/tools/prediction/mlp_train/cruiseMLP_train.py
modules/tools/prediction/mlp_train/cruiseMLP_train.py
+13
-12
未找到文件。
modules/tools/prediction/mlp_train/cruiseMLP_train.py
浏览文件 @
c097b926
...
...
@@ -104,7 +104,7 @@ class FCNN_CNN1D(torch.nn.Module):
def
__init__
(
self
):
super
(
FCNN_CNN1D
,
self
).
__init__
()
self
.
lane_feature_conv
=
torch
.
nn
.
Sequential
(
\
nn
.
Conv1d
(
6
,
10
,
3
,
stride
=
1
),
\
nn
.
Conv1d
(
5
,
10
,
3
,
stride
=
1
),
\
nn
.
ReLU
(),
\
nn
.
Conv1d
(
10
,
16
,
3
,
stride
=
2
),
\
nn
.
ReLU
(),
\
...
...
@@ -115,7 +115,7 @@ class FCNN_CNN1D(torch.nn.Module):
self
.
lane_feature_dropout
=
nn
.
Dropout
(
0.0
)
self
.
obs_feature_fc
=
torch
.
nn
.
Sequential
(
\
nn
.
Linear
(
2
3
,
17
),
\
nn
.
Linear
(
2
4
,
17
),
\
nn
.
Sigmoid
(),
\
nn
.
Dropout
(
0.0
),
\
nn
.
Linear
(
17
,
12
),
\
...
...
@@ -124,7 +124,7 @@ class FCNN_CNN1D(torch.nn.Module):
)
self
.
classify
=
torch
.
nn
.
Sequential
(
\
nn
.
Linear
(
12
3
,
66
),
\
nn
.
Linear
(
12
4
,
66
),
\
nn
.
Sigmoid
(),
\
nn
.
Dropout
(
0.3
),
\
...
...
@@ -140,7 +140,7 @@ class FCNN_CNN1D(torch.nn.Module):
#nn.Sigmoid()
)
self
.
regress
=
torch
.
nn
.
Sequential
(
\
nn
.
Linear
(
12
4
,
77
),
\
nn
.
Linear
(
12
5
,
77
),
\
nn
.
ReLU
(),
\
nn
.
Dropout
(
0.2
),
\
...
...
@@ -156,12 +156,13 @@ class FCNN_CNN1D(torch.nn.Module):
nn
.
ReLU
()
)
def
forward
(
self
,
x
):
lane_fea
=
x
[:,
2
3
:]
lane_fea
=
lane_fea
.
view
(
lane_fea
.
size
(
0
),
6
,
30
)
lane_fea
=
x
[:,
2
4
:]
lane_fea
=
lane_fea
.
view
(
lane_fea
.
size
(
0
),
5
,
30
)
obs_fea
=
x
[:,:
2
3
]
obs_fea
=
x
[:,:
2
4
]
lane_fea
=
self
.
lane_feature_conv
(
lane_fea
)
#print (lane_fea.shape)
lane_fea_max
=
self
.
lane_feature_maxpool
(
lane_fea
)
lane_fea_avg
=
self
.
lane_feature_avgpool
(
lane_fea
)
...
...
@@ -346,7 +347,7 @@ def train_vanilla(train_X, train_y, model, optimizer, epoch, batch_size=2048):
c_pred
,
r_pred
=
model
(
X
)
loss
=
loss_fn
(
c_pred
,
r_pred
,
y
)
#loss.data[0].cpu().numpy()
loss_history
.
append
(
loss
.
data
[
0
]
)
loss_history
.
append
(
loss
.
data
)
loss
.
backward
()
optimizer
.
step
()
train_correct_class
+=
\
...
...
@@ -384,7 +385,7 @@ def train_dataloader(train_loader, model, optimizer, epoch):
c_pred
,
r_pred
=
model
(
X
)
loss
=
loss_fn
(
c_pred
,
r_pred
,
y
)
#loss.data[0].cpu().numpy()
loss_history
.
append
(
loss
.
data
[
0
]
)
loss_history
.
append
(
loss
.
data
)
loss
.
backward
()
optimizer
.
step
()
...
...
@@ -419,7 +420,7 @@ def validate_vanilla(valid_X, valid_y, model, batch_size=2048):
y
=
valid_y
[
i
*
batch_size
:
min
(
num_of_data
,
(
i
+
1
)
*
batch_size
),]
c_pred
,
r_pred
=
model
(
X
)
valid_loss
=
loss_fn
(
c_pred
,
r_pred
,
y
)
loss_history
.
append
(
valid_loss
.
data
[
0
]
)
loss_history
.
append
(
valid_loss
.
data
)
c_pred
=
c_pred
.
data
.
cpu
().
numpy
()
c_pred
=
c_pred
.
reshape
(
c_pred
.
shape
[
0
],
1
)
...
...
@@ -472,7 +473,7 @@ def validate_dataloader(valid_loader, model):
y
=
y
.
float
().
cuda
()
c_pred
,
r_pred
=
model
(
X
)
valid_loss
=
loss_fn
(
c_pred
,
r_pred
,
y
)
loss_history
.
append
(
valid_loss
.
data
[
0
]
)
loss_history
.
append
(
valid_loss
.
data
)
valid_correct_class
+=
\
np
.
sum
((
c_pred
.
data
.
cpu
().
numpy
()
>
0.5
).
astype
(
float
)
==
\
y
[:,
0
].
data
.
cpu
().
numpy
().
reshape
(
c_pred
.
data
.
cpu
().
numpy
().
shape
[
0
],
1
))
...
...
@@ -518,7 +519,7 @@ if __name__ == "__main__":
X_train
,
y_train
,
X_valid
,
y_valid
=
data_preprocessing
(
train_data
)
# Model declaration
model
=
F
ullyConn_NN
()
model
=
F
CNN_CNN1D
()
print
(
"The model used is: "
)
print
(
model
)
learning_rate
=
6.561e-4
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录