提交 09a602ba 编写于 作者: P panjiacheng 提交者: Jiangtao Hu

Prediction: added all available features.

上级 4ff9448e
......@@ -255,7 +255,7 @@ def data_preprocessing(data):
# mask out those that don't have any history
mask5 = (data[:,53] != 100)
X = np.concatenate((X_obs_hist_5, X_lane), axis=1)
X = np.concatenate((X_obs_old_features, X_surround_obs, X_obs_hist_5, X_lane), axis=1)
X = X[mask5, :]
y = data[:, -dim_output:]
y = y[mask5, :]
......@@ -398,7 +398,7 @@ def validate_vanilla(valid_X, valid_y, model, batch_size=2048, balance=1.0, pos_
valid_y = valid_y.data.cpu().numpy()
valid_auc = sklearn.metrics.roc_auc_score(valid_y[:,0], pred_y.reshape(-1))
pred_y = (pred_y > 0.5)
pred_y = (pred_y > 0.0)
valid_accuracy = sklearn.metrics.accuracy_score(valid_y[:,0], pred_y.reshape(-1))
valid_precision = sklearn.metrics.precision_score(valid_y[:,0], pred_y.reshape(-1), pos_label=pos_label)
valid_recall = sklearn.metrics.recall_score(valid_y[:,0], pred_y.reshape(-1), pos_label=pos_label)
......
......@@ -115,10 +115,10 @@ class FCNN_CNN1D(torch.nn.Module):
self.lane_feature_dropout = nn.Dropout(0.0)
self.obs_feature_fc = torch.nn.Sequential(\
nn.Linear(24, 17),\
nn.Linear(55, 32),\
nn.Sigmoid(),\
nn.Dropout(0.0),\
nn.Linear(17, 12),\
nn.Linear(32, 24),\
nn.Sigmoid(),\
nn.Dropout(0.0),\
)
......@@ -156,13 +156,13 @@ class FCNN_CNN1D(torch.nn.Module):
nn.ReLU()
)
def forward(self, x):
lane_fea = x[:,24:]
lane_fea = x[:,55:]
lane_fea = lane_fea.view(lane_fea.size(0), 5, 30)
obs_fea = x[:,:24]
obs_fea = x[:,:55]
lane_fea = self.lane_feature_conv(lane_fea)
#print (lane_fea.shape)
lane_fea_max = self.lane_feature_maxpool(lane_fea)
lane_fea_avg = self.lane_feature_avgpool(lane_fea)
......@@ -170,10 +170,10 @@ class FCNN_CNN1D(torch.nn.Module):
lane_fea_avg.view(lane_fea_avg.size(0),-1)], 1)
lane_fea = self.lane_feature_dropout(lane_fea)
#obs_fea = self.obs_feature_fc(obs_fea)
#print (lane_fea.shape)
obs_fea = self.obs_feature_fc(obs_fea)
tot_fea = torch.cat([lane_fea, obs_fea], 1)
out_c = self.classify(tot_fea)
out_r = self.regress(torch.cat([tot_fea, out_c], 1))
return out_c, out_r
\ No newline at end of file
return out_c, out_r
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册