Commit 788f63f1 authored by Pavol Mulinka

initial ZILN loss commit

Parent 4de81ea2
{
"cells": [
{
"cell_type": "markdown",
"id": "6658e02e",
"metadata": {},
"source": [
"# Losses"
]
},
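{
"cell_type": "markdown",
"id": "ziln-formula-md",
"metadata": {},
"source": [
"This notebook ports the Zero Inflated LogNormal (ZILN) loss of [Wang et al. (2019)](https://arxiv.org/pdf/1912.07753.pdf) to PyTorch. The network emits three outputs per sample: a classification logit (is the target positive?) and the `loc` and `scale` of a lognormal distribution. Writing $p$ for the predicted probability of a positive target, the loss from the paper is\n",
"\n",
"$$\\mathcal{L}(y) = -\\mathbb{1}_{\\{y=0\\}}\\log(1-p) - \\mathbb{1}_{\\{y>0\\}}\\bigl(\\log p + \\log f_{\\mathrm{LogNormal}}(y;\\,\\mu,\\sigma)\\bigr)$$\n",
"\n",
"i.e. a binary cross-entropy on whether the target is positive, plus the lognormal negative log-likelihood evaluated only on the positive targets."
]
},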
{
"cell_type": "markdown",
"id": "639ad78b",
"metadata": {},
"source": [
"## Initial imports"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d40952c1",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"\n",
"# increase displayed columns in jupyter notebook\n",
"pd.set_option(\"display.max_columns\", 200)\n",
"pd.set_option(\"display.max_rows\", 300)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "9864678c",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from pytorch_widedeep.wdtypes import *\n",
"\n",
"class ZILNLoss(nn.Module):\n",
" r\"\"\"Implementation of the `Zero Inflated LogNormal loss\n",
" <https://arxiv.org/pdf/1912.07753.pdf>`\n",
" \"\"\"\n",
"\n",
" def __init__(self):\n",
" super().__init__()\n",
"\n",
" def forward(self, input: Tensor, target: Tensor) -> Tensor:\n",
" r\"\"\"\n",
" Parameters\n",
" ----------\n",
" input: Tensor\n",
" input tensor with predictions (not probabilities)\n",
" target: Tensor\n",
" target tensor with the actual classes\n",
"\n",
" Examples\n",
" --------\n",
" >>> import torch\n",
" >>>\n",
" >>> from pytorch_widedeep.losses import ZILNLoss\n",
" >>>\n",
" >>> # REGRESSION\n",
" >>> target = torch.tensor([[0., 1.5]]).view(-1, 1)\n",
" >>> input = torch.tensor([[.1, .2, .3], [.4, .5, .6]])\n",
" >>> ZILNLoss()(input, target)\n",
" tensor([0.6287, 1.9941])\n",
" \"\"\"\n",
" positive = target>0\n",
" positive = positive.float()\n",
"\n",
" assert input.shape == torch.Size([target.shape[0], 3]), \"Wrong shape of input.\"\n",
" positive_input = input[..., :1]\n",
"\n",
" classification_loss = F.binary_cross_entropy_with_logits(positive_input, positive)\n",
"\n",
" loc = input[..., 1:2]\n",
" scale = torch.maximum(\n",
" F.softplus(input[..., 2:]),\n",
" torch.sqrt(torch.Tensor([torch.finfo(torch.float32).eps])))\n",
" safe_labels = positive * target + (\n",
" 1 - positive) * torch.ones_like(target)\n",
"\n",
" regression_loss = -torch.mean(\n",
" positive * torch.distributions.log_normal.LogNormal(loc=loc, scale=scale).log_prob(safe_labels),\n",
" dim=-1)\n",
"\n",
" return classification_loss + regression_loss"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "d1073cc6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.6287, 1.9941])"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"target = torch.tensor([[0., 1.5]]).view(-1, 1)\n",
"input = torch.tensor([[.1, .2, .3], [.4, .5, .6]])\n",
"ZILNLoss()(input, target)"
]
},
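{
"cell_type": "markdown",
"id": "ziln-pred-md",
"metadata": {},
"source": [
"The Keras reference below also ships a `zero_inflated_lognormal_pred` function that turns the three logits into a predicted mean. A minimal PyTorch counterpart could look like the sketch below (`ziln_pred` is a hypothetical helper added for completeness, not part of this commit):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ziln-pred-code",
"metadata": {},
"outputs": [],
"source": [
"def ziln_pred(input: Tensor) -> Tensor:\n",
" # sketch of a PyTorch counterpart to zero_inflated_lognormal_pred below\n",
" positive_probs = torch.sigmoid(input[..., :1])\n",
" loc = input[..., 1:2]\n",
" scale = F.softplus(input[..., 2:])\n",
" # the mean of LogNormal(loc, scale) is exp(loc + scale^2 / 2),\n",
" # weighted by the predicted probability of a positive target\n",
" return positive_probs * torch.exp(loc + 0.5 * scale.pow(2))\n",
"\n",
"\n",
"ziln_pred(input)"
]
},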
{
"cell_type": "markdown",
"id": "6f91526c",
"metadata": {},
"source": [
"# Keras implementation - original"
]
},
{
"cell_type": "markdown",
"id": "9fc84a5f",
"metadata": {},
"source": [
"* https://github.com/google/lifetime_value/blob/master/lifetime_value/zero_inflated_lognormal.py"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9ac7bf00",
"metadata": {},
"outputs": [],
"source": [
"import tensorflow.compat.v1 as tf\n",
"import tensorflow_probability as tfp\n",
"tfd = tfp.distributions\n",
"\n",
"\n",
"def zero_inflated_lognormal_pred(logits: tf.Tensor) -> tf.Tensor:\n",
" \"\"\"Calculates predicted mean of zero inflated lognormal logits.\n",
" Arguments:\n",
" logits: [batch_size, 3] tensor of logits.\n",
" Returns:\n",
" preds: [batch_size, 1] tensor of predicted mean.\n",
" \"\"\"\n",
" logits = tf.convert_to_tensor(logits, dtype=tf.float32)\n",
" positive_probs = tf.keras.backend.sigmoid(logits[..., :1])\n",
" loc = logits[..., 1:2]\n",
" scale = tf.keras.backend.softplus(logits[..., 2:])\n",
" preds = (\n",
" positive_probs *\n",
" tf.keras.backend.exp(loc + 0.5 * tf.keras.backend.square(scale)))\n",
" return preds\n",
"\n",
"\n",
"def zero_inflated_lognormal_loss(labels: tf.Tensor,\n",
" logits: tf.Tensor) -> tf.Tensor:\n",
" \"\"\"Computes the zero inflated lognormal loss.\n",
" Usage with tf.keras API:\n",
" ```python\n",
" model = tf.keras.Model(inputs, outputs)\n",
" model.compile('sgd', loss=zero_inflated_lognormal)\n",
" ```\n",
" Arguments:\n",
" labels: True targets, tensor of shape [batch_size, 1].\n",
" logits: Logits of output layer, tensor of shape [batch_size, 3].\n",
" Returns:\n",
" Zero inflated lognormal loss value.\n",
" \"\"\"\n",
" labels = tf.convert_to_tensor(labels, dtype=tf.float32)\n",
" positive = tf.cast(labels > 0, tf.float32)\n",
"\n",
" logits = tf.convert_to_tensor(logits, dtype=tf.float32)\n",
" logits.shape.assert_is_compatible_with(\n",
" tf.TensorShape(labels.shape[:-1].as_list() + [3]))\n",
"\n",
" positive_logits = logits[..., :1]\n",
" classification_loss = tf.keras.losses.binary_crossentropy(\n",
" y_true=positive, y_pred=positive_logits, from_logits=True)\n",
"\n",
" loc = logits[..., 1:2]\n",
" scale = tf.math.maximum(\n",
" tf.keras.backend.softplus(logits[..., 2:]),\n",
" tf.math.sqrt(tf.keras.backend.epsilon()))\n",
" safe_labels = positive * labels + (\n",
" 1 - positive) * tf.keras.backend.ones_like(labels)\n",
" regression_loss = -tf.keras.backend.mean(\n",
" positive * tfd.LogNormal(loc=loc, scale=scale).log_prob(safe_labels),\n",
" axis=-1)\n",
"\n",
" return classification_loss + regression_loss"
]
},
{
"cell_type": "markdown",
"id": "2256b83b",
"metadata": {},
"source": [
"* https://github.com/google/lifetime_value/blob/master/lifetime_value/zero_inflated_lognormal_test.py"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "0ceaf7ac",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy import stats\n",
"import tensorflow.compat.v1 as tf\n",
"\n",
"\n",
"# Absolute error tolerance in asserting array near.\n",
"_ERR_TOL = 1e-6\n",
"\n",
"# softplus function that calculates log(1+exp(x))\n",
"_softplus = lambda x: np.log(1.0 + np.exp(x))\n",
"\n",
"# sigmoid function that calculates 1/(1+exp(-x))\n",
"_sigmoid = lambda x: 1 / (1 + np.exp(-x))\n",
"\n",
"\n",
"class ZeroInflatedLognormalLossTest():\n",
"\n",
" def setUp(self):\n",
" super(ZeroInflatedLognormalLossTest, self).setUp()\n",
" self.logits = np.array([[.1, .2, .3], [.4, .5, .6]])\n",
" self.labels = np.array([[0.], [1.5]])\n",
"\n",
" def zero_inflated_lognormal(self, labels, logits):\n",
" positive_logits = logits[..., :1]\n",
" loss_zero = _softplus(positive_logits)\n",
" loc = logits[..., 1:2]\n",
" scale = np.maximum(\n",
" _softplus(logits[..., 2:]),\n",
" np.sqrt(tf.keras.backend.epsilon()))\n",
" log_prob_non_zero = stats.lognorm.logpdf(\n",
" x=labels, s=scale, loc=0, scale=np.exp(loc))\n",
" loss_non_zero = _softplus(-positive_logits) - log_prob_non_zero\n",
" return np.mean(np.where(labels == 0., loss_zero, loss_non_zero), axis=-1)\n",
"\n",
" def test_loss_value(self):\n",
" expected_loss = self.zero_inflated_lognormal(self.labels, self.logits)\n",
" loss = zero_inflated_lognormal.zero_inflated_lognormal_loss(\n",
" self.labels, self.logits)\n",
" self.assertArrayNear(self.evaluate(loss), expected_loss, _ERR_TOL)"
]
}
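,
{
"cell_type": "markdown",
"id": "ziln-crosscheck-md",
"metadata": {},
"source": [
"A quick cross-check of the PyTorch port against the scipy reference (a sketch, not executed here). `ZILNLoss` reduces the classification term with a batch mean while the reference keeps it per sample, so the per-sample losses differ, but the batch means should agree:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ziln-crosscheck-code",
"metadata": {},
"outputs": [],
"source": [
"logits = np.array([[.1, .2, .3], [.4, .5, .6]])\n",
"labels = np.array([[0.], [1.5]])\n",
"\n",
"# instantiate the test class just to reuse its scipy/numpy reference\n",
"ref = ZeroInflatedLognormalLossTest().zero_inflated_lognormal(labels, logits)\n",
"\n",
"# PyTorch port defined at the top of the notebook\n",
"pt = ZILNLoss()(torch.tensor(logits, dtype=torch.float32),\n",
" torch.tensor(labels, dtype=torch.float32))\n",
"\n",
"# per-sample values differ, batch means should match; the tolerance is\n",
"# slightly looser than _ERR_TOL because the port runs in float32\n",
"np.allclose(pt.mean().item(), ref.mean(), atol=1e-5)"
]
}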
],
"metadata": {
"interpreter": {
"hash": "3b99005fd577fa40f3cce433b2b92303885900e634b2b5344c07c59d06c8792d"
},
"kernelspec": {
"display_name": "Python 3.8.5 64-bit ('base': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@@ -7,6 +7,57 @@ from pytorch_widedeep.wdtypes import * # noqa: F403
use_cuda = torch.cuda.is_available()
class ZILNLoss(nn.Module):
r"""Implementation of the `Zero Inflated LogNormal loss
<https://arxiv.org/pdf/1912.07753.pdf>`
"""
def __init__(self):
super().__init__()
def forward(self, input: Tensor, target: Tensor) -> Tensor:
r"""
Parameters
----------
input: Tensor
input tensor of shape ``[batch_size, 3]`` holding, per sample, the
classification logit, the lognormal ``loc`` and the (pre-softplus) ``scale``
target: Tensor
target tensor with the actual values; zeros mark the non-positive class
Examples
--------
>>> import torch
>>>
>>> from pytorch_widedeep.losses import ZILNLoss
>>>
>>> # REGRESSION
>>> target = torch.tensor([[0., 1.5]]).view(-1, 1)
>>> input = torch.tensor([[.1, .2, .3], [.4, .5, .6]])
>>> ZILNLoss()(input, target)
tensor([0.6287, 1.9941])
"""
positive = (target > 0).float()
assert input.shape == torch.Size([target.shape[0], 3]), "Wrong shape of input."
positive_input = input[..., :1]
# NOTE: the default 'mean' reduction returns a scalar that is broadcast
# across the batch; the Keras original keeps a per-sample value
classification_loss = F.binary_cross_entropy_with_logits(positive_input, positive)
loc = input[..., 1:2]
# clamp the scale from below, keeping the eps tensor on the input device
scale = torch.maximum(
F.softplus(input[..., 2:]),
torch.sqrt(torch.tensor(torch.finfo(torch.float32).eps, device=input.device)))
# replace zero targets with ones so log_prob is well defined;
# those entries are masked out by `positive` below
safe_labels = positive * target + (
1 - positive) * torch.ones_like(target)
regression_loss = -torch.mean(
positive * torch.distributions.log_normal.LogNormal(loc=loc, scale=scale).log_prob(safe_labels),
dim=-1)
return classification_loss + regression_loss
class FocalLoss(nn.Module):
r"""Implementation of the `focal loss
<https://arxiv.org/pdf/1708.02002.pdf>`_ for both binary and
......