Commit 765520e0 authored by: Pavol Mulinka

adjusted ZILNLoss to return mean

Parent 612da44a
......@@ -18,7 +18,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 71,
"id": "d40952c1",
"metadata": {},
"outputs": [],
......@@ -41,12 +41,13 @@
"# Pytorch implementation\n",
"* reduction=\"none\" in binary_cross_entropy_with_logits gives same result for given testing usecase as authors implementation https://pytorch.org/docs/stable/generated/torch.nn.functional.binary_cross_entropy_with_logits.html\n",
" * hovewer keras BCE uses auto reduction by default https://www.tensorflow.org/api_docs/python/tf/keras/losses/BinaryCrossentropy which changes according to usecase : so wtf? https://www.tensorflow.org/api_docs/python/tf/keras/losses/Reduction\n",
"* keras also return flattened BCE result"
"* keras also return flattened BCE result\n",
"* keras needs loss per sample and it averages it in the backend: https://keras.io/api/losses/"
]
},
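As a sanity check on the reduction note in the cell above, the following standalone sketch (not part of this commit) compares `reduction="none"` with the default `reduction="mean"` in `F.binary_cross_entropy_with_logits`: the former keeps one loss per sample, the latter already collapses them to the scalar that Keras' backend averaging ultimately produces.

```python
# Standalone sketch (not from the notebook): comparing the reduction modes of
# F.binary_cross_entropy_with_logits. "none" keeps per-sample losses, while the
# default "mean" averages them into a single scalar.
import torch
import torch.nn.functional as F

logits = torch.tensor([0.2, -1.3, 0.7])
targets = torch.tensor([1.0, 0.0, 1.0])

per_sample = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
averaged = F.binary_cross_entropy_with_logits(logits, targets, reduction="mean")

print(per_sample)   # one loss value per sample
print(averaged)     # a single scalar
assert torch.isclose(per_sample.mean(), averaged)
```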
{
"cell_type": "code",
"execution_count": 59,
"execution_count": 83,
"id": "9864678c",
"metadata": {},
"outputs": [],
......@@ -56,6 +57,7 @@
"import torch.nn.functional as F\n",
"from pytorch_widedeep.wdtypes import *\n",
"\n",
"\n",
"def _predict_ziln(preds: Tensor) -> Tensor:\n",
" \"\"\"Calculates predicted mean of zero inflated lognormal logits.\n",
" Adjusted implementaion of `code\n",
......@@ -123,54 +125,33 @@
" positive * torch.distributions.log_normal.LogNormal(loc=loc, scale=scale).log_prob(safe_labels),\n",
" dim=-1)\n",
"\n",
" return classification_loss + regression_loss"
" return torch.mean(classification_loss + regression_loss)"
]
},
{
"cell_type": "code",
"execution_count": 65,
"execution_count": 84,
"id": "d1073cc6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.7444, 1.8784])"
"tensor(1.3114)"
]
},
"execution_count": 65,
"execution_count": 84,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"use_cuda = torch.cuda.is_available()\n",
"target = torch.tensor([[0., 1.5]]).view(-1, 1)\n",
"input = torch.tensor([[.1, .2, .3], [.4, .5, .6]])\n",
"ZILNLoss()(input, target)"
]
},
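The change in the cell output above is pure averaging: the previous revision of `ZILNLoss` returned one loss per sample, and the new scalar is their mean. A quick check with the printed values:

```python
# Sketch using the values printed above: the scalar now returned by ZILNLoss
# is the mean of the per-sample losses returned before this commit.
import torch

per_sample = torch.tensor([0.7444, 1.8784])  # output of the previous revision
print(per_sample.mean())                     # tensor(1.3114) -- output after this commit
```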
{
"cell_type": "code",
"execution_count": 66,
"id": "632af7b0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.9236],\n",
" [1.6908]])"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"_predict_ziln(input)"
]
},
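The body of `_predict_ziln` is elided by the diff, but its printed output is consistent with the standard zero-inflated lognormal mean `p * exp(loc + scale**2 / 2)`, taking `p` from a sigmoid of the first logit, `loc` from the second, and `scale` from a softplus of the third. The sketch below is a hedged reconstruction under that assumption, not the commit's code:

```python
# Hedged reconstruction (the diff elides the real body of _predict_ziln):
# zero-inflated lognormal mean  p * exp(loc + scale**2 / 2),  with
# p = sigmoid(logit 0), loc = logit 1, scale = softplus(logit 2).
import torch
import torch.nn.functional as F

preds = torch.tensor([[.1, .2, .3], [.4, .5, .6]])
p = torch.sigmoid(preds[..., :1])
loc = preds[..., 1:2]
scale = F.softplus(preds[..., 2:])

print(p * torch.exp(loc + 0.5 * scale ** 2))  # matches tensor([[0.9236], [1.6908]])
```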
{
"cell_type": "markdown",
"id": "6f91526c",
......
......@@ -112,7 +112,7 @@ class ZILNLoss(nn.Module):
positive * torch.distributions.log_normal.LogNormal(loc=loc, scale=scale).log_prob(safe_labels),
dim=-1)
return classification_loss + regression_loss
return torch.mean(classification_loss + regression_loss)
class FocalLoss(nn.Module):
......
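Returning the mean also makes the loss directly usable in a standard training step, since `Tensor.backward()` expects a scalar unless an explicit gradient is passed. A minimal sketch under that assumption; the import path and the toy tensors (reused from the notebook above) are illustrative, not part of this commit:

```python
# Minimal sketch: a scalar loss can call backward() directly; the per-sample
# tensor returned before this change would need an explicit gradient argument.
# The import path is assumed; the tensors reuse the toy values from the notebook.
import torch
from pytorch_widedeep.losses import ZILNLoss  # assumed import path

input = torch.tensor([[.1, .2, .3], [.4, .5, .6]], requires_grad=True)
target = torch.tensor([[0., 1.5]]).view(-1, 1)

loss = ZILNLoss()(input, target)  # scalar, e.g. tensor(1.3114)
loss.backward()                   # works because the loss is a scalar
```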