提交 76121f3c 编写于 作者: P panjiacheng 提交者: Jiangtao Hu

Prediction: also predicting arrival time.

上级 d452c9d2
......@@ -297,14 +297,14 @@ def loss_fn(c_pred, r_pred, target):
loss_C = nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor([1.0]).cuda()) #nn.BCELoss()
loss_R = nn.MSELoss()
loss = loss_C(c_pred, target[:,0].view(target.shape[0],1))
#loss = 4 * loss_C(c_pred, target[:,0].view(target.shape[0],1)) + \
#loss = loss_C(c_pred, target[:,0].view(target.shape[0],1))
loss = 4 * loss_C(c_pred, target[:,0].view(target.shape[0],1)) + \
loss_R(((target[:,2] > 0.0) * (target[:,2] <= 3.0)).float().view(target.shape[0],1) * r_pred + \
((target[:,2] <= 0.0) + (target[:,2] > 3.0)).float().view(target.shape[0],1) * target[:,2].view(target.shape[0],1), \
target[:,2].view(target.shape[0],1))
#loss_R((target[:,1] < 10.0).float().view(target.shape[0],1) * r_pred + \
# (target[:,1] >= 10.0).float().view(target.shape[0],1) * target[:,1].view(target.shape[0],1), \
# target[:,1].view(target.shape[0],1))
#loss_R((target[:,0] == True).float().view(target.shape[0],1) * r_pred + \
# (target[:,0] == False).float().view(target.shape[0],1) * target[:,1].view(target.shape[0],1), \
# target[:,1].view(target.shape[0],1))
return loss
......@@ -677,8 +677,7 @@ if __name__ == "__main__":
valid_loss = validate_vanilla(X_valid, y_valid, model)
scheduler.step(valid_loss)
if valid_loss < best_valid_loss:
torch.
#torch.save(model.state_dict(), './cruiseMLP_saved_model.pt')
torch.save(model.state_dict(), './cruiseMLP_saved_model.pt')
else:
train_dir = args.train_file
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册