From 788f63f1b06a23c34c82218348d5f12febcb3bd7 Mon Sep 17 00:00:00 2001
From: Pavol Mulinka
Date: Tue, 16 Nov 2021 14:09:59 +0100
Subject: [PATCH] initial ZILN loss commit

---
 examples/14_Losses-temp.ipynb | 298 ++++++++++++++++++++++++++++++++++
 pytorch_widedeep/losses.py    |  51 ++++++
 2 files changed, 349 insertions(+)
 create mode 100644 examples/14_Losses-temp.ipynb

diff --git a/examples/14_Losses-temp.ipynb b/examples/14_Losses-temp.ipynb
new file mode 100644
index 0000000..21e43fb
--- /dev/null
+++ b/examples/14_Losses-temp.ipynb
@@ -0,0 +1,298 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "6658e02e",
+   "metadata": {},
+   "source": [
+    "# Losses"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "639ad78b",
+   "metadata": {},
+   "source": [
+    "## Initial imports"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "d40952c1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import torch\n",
+    "\n",
+    "\n",
+    "# increase displayed columns in jupyter notebook\n",
+    "pd.set_option(\"display.max_columns\", 200)\n",
+    "pd.set_option(\"display.max_rows\", 300)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 52,
+   "id": "9864678c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "from pytorch_widedeep.wdtypes import *\n",
+    "\n",
+    "\n",
+    "class ZILNLoss(nn.Module):\n",
+    "    r\"\"\"Implementation of the `Zero Inflated LogNormal loss\n",
+    "    <https://github.com/google/lifetime_value/blob/master/lifetime_value/zero_inflated_lognormal.py>`_,\n",
+    "    adapted from the original Keras implementation.\n",
+    "    \"\"\"\n",
+    "\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "\n",
+    "    def forward(self, input: Tensor, target: Tensor) -> Tensor:\n",
+    "        r\"\"\"\n",
+    "        Parameters\n",
+    "        ----------\n",
+    "        input: Tensor\n",
+    "            input tensor with predictions (not probabilities) of shape\n",
+    "            (batch_size, 3)\n",
+    "        target: Tensor\n",
+    "            target tensor with the actual values\n",
+    "\n",
+    "        Examples\n",
+    "        --------\n",
+    "        >>> import torch\n",
+    "        >>>\n",
+    "        >>> from pytorch_widedeep.losses import ZILNLoss\n",
+    "        >>>\n",
+    "        >>> # REGRESSION\n",
+    "        >>> target = torch.tensor([[0., 1.5]]).view(-1, 1)\n",
+    "        >>> input = torch.tensor([[.1, .2, .3], [.4, .5, .6]])\n",
+    "        >>> ZILNLoss()(input, target)\n",
+    "        tensor([0.7444, 1.8784])\n",
+    "        \"\"\"\n",
+    "        positive = target > 0\n",
+    "        positive = positive.float()\n",
+    "\n",
+    "        assert input.shape == torch.Size([target.shape[0], 3]), \"Wrong shape of input.\"\n",
+    "        # column 0 holds the logit for P(target > 0), column 1 the lognormal\n",
+    "        # loc, column 2 the raw (pre-softplus) scale\n",
+    "        positive_input = input[..., :1]\n",
+    "\n",
+    "        # per-example classification loss (mean over the last dim, as in the\n",
+    "        # Keras original) so that it adds to the per-example regression loss\n",
+    "        classification_loss = F.binary_cross_entropy_with_logits(\n",
+    "            positive_input, positive, reduction=\"none\"\n",
+    "        ).mean(dim=-1)\n",
+    "\n",
+    "        loc = input[..., 1:2]\n",
+    "        # clamp the scale from below so the lognormal stays well defined\n",
+    "        scale = torch.maximum(\n",
+    "            F.softplus(input[..., 2:]),\n",
+    "            torch.sqrt(torch.tensor(torch.finfo(torch.float32).eps, device=input.device)))\n",
+    "        # replace zero targets with ones so that log_prob is finite; these\n",
+    "        # terms are masked out by ``positive`` anyway\n",
+    "        safe_labels = positive * target + (\n",
+    "            1 - positive) * torch.ones_like(target)\n",
+    "\n",
+    "        regression_loss = -torch.mean(\n",
+    "            positive * torch.distributions.log_normal.LogNormal(loc=loc, scale=scale).log_prob(safe_labels),\n",
+    "            dim=-1)\n",
+    "\n",
+    "        return classification_loss + regression_loss"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 53,
+   "id": "d1073cc6",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([0.7444, 1.8784])"
+      ]
+     },
+     "execution_count": 53,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "target = torch.tensor([[0., 1.5]]).view(-1, 1)\n",
+    "input = torch.tensor([[.1, .2, .3], [.4, .5, .6]])\n",
+    "ZILNLoss()(input, target)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6f91526c",
+   "metadata": {},
+   "source": [
+    "# Keras implementation - original"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9fc84a5f",
+   "metadata": {},
+   "source": [
+    "* https://github.com/google/lifetime_value/blob/master/lifetime_value/zero_inflated_lognormal.py"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "9ac7bf00",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import tensorflow.compat.v1 as tf\n",
+    "import tensorflow_probability as tfp\n",
+    "tfd = tfp.distributions\n",
+    "\n",
+    "\n",
+    "def zero_inflated_lognormal_pred(logits: tf.Tensor) -> tf.Tensor:\n",
+    "  \"\"\"Calculates predicted mean of zero inflated lognormal logits.\n",
+    "\n",
+    "  Arguments:\n",
+    "    logits: [batch_size, 3] tensor of logits.\n",
+    "\n",
+    "  Returns:\n",
+    "    preds: [batch_size, 1] tensor of predicted mean.\n",
+    "  \"\"\"\n",
+    "  logits = tf.convert_to_tensor(logits, dtype=tf.float32)\n",
+    "  positive_probs = tf.keras.backend.sigmoid(logits[..., :1])\n",
+    "  loc = logits[..., 1:2]\n",
+    "  scale = tf.keras.backend.softplus(logits[..., 2:])\n",
+    "  preds = (\n",
+    "      positive_probs *\n",
+    "      tf.keras.backend.exp(loc + 0.5 * tf.keras.backend.square(scale)))\n",
+    "  return preds\n",
+    "\n",
+    "\n",
+    "def zero_inflated_lognormal_loss(labels: tf.Tensor,\n",
+    "                                 logits: tf.Tensor) -> tf.Tensor:\n",
+    "  \"\"\"Computes the zero inflated lognormal loss.\n",
+    "\n",
+    "  Usage with tf.keras API:\n",
+    "\n",
+    "  ```python\n",
+    "  model = tf.keras.Model(inputs, outputs)\n",
+    "  model.compile('sgd', loss=zero_inflated_lognormal)\n",
+    "  ```\n",
+    "\n",
+    "  Arguments:\n",
+    "    labels: True targets, tensor of shape [batch_size, 1].\n",
+    "    logits: Logits of output layer, tensor of shape [batch_size, 3].\n",
+    "\n",
+    "  Returns:\n",
+    "    Zero inflated lognormal loss value.\n",
+    "  \"\"\"\n",
+    "  labels = tf.convert_to_tensor(labels, dtype=tf.float32)\n",
+    "  positive = tf.cast(labels > 0, tf.float32)\n",
+    "\n",
+    "  logits = tf.convert_to_tensor(logits, dtype=tf.float32)\n",
+    "  logits.shape.assert_is_compatible_with(\n",
+    "      tf.TensorShape(labels.shape[:-1].as_list() + [3]))\n",
+    "\n",
+    "  positive_logits = logits[..., :1]\n",
+    "  classification_loss = tf.keras.losses.binary_crossentropy(\n",
+    "      y_true=positive, y_pred=positive_logits, from_logits=True)\n",
+    "\n",
+    "  loc = logits[..., 1:2]\n",
+    "  scale = tf.math.maximum(\n",
+    "      tf.keras.backend.softplus(logits[..., 2:]),\n",
+    "      tf.math.sqrt(tf.keras.backend.epsilon()))\n",
+    "  safe_labels = positive * labels + (\n",
+    "      1 - positive) * tf.keras.backend.ones_like(labels)\n",
+    "  regression_loss = -tf.keras.backend.mean(\n",
+    "      positive * tfd.LogNormal(loc=loc, scale=scale).log_prob(safe_labels),\n",
+    "      axis=-1)\n",
+    "\n",
+    "  return classification_loss + regression_loss"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2256b83b",
+   "metadata": {},
+   "source": [
+    "* https://github.com/google/lifetime_value/blob/master/lifetime_value/zero_inflated_lognormal_test.py"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "0ceaf7ac",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "from scipy import stats\n",
+    "import tensorflow.compat.v1 as tf\n",
+    "\n",
+    "\n",
+    "# Absolute error tolerance in asserting array near.\n",
+    "_ERR_TOL = 1e-6\n",
+    "\n",
+    "# softplus function that calculates log(1+exp(x))\n",
+    "_softplus = lambda x: np.log(1.0 + np.exp(x))\n",
+    "\n",
+    "# sigmoid function that calculates 1/(1+exp(-x))\n",
+    "_sigmoid = lambda x: 1 / (1 + np.exp(-x))\n",
+    "\n",
+    "\n",
+    "# subclass tf.test.TestCase so that ``self.evaluate`` and\n",
+    "# ``self.assertArrayNear`` used below are actually defined\n",
+    "class ZeroInflatedLognormalLossTest(tf.test.TestCase):\n",
+    "\n",
+    "  def setUp(self):\n",
+    "    super(ZeroInflatedLognormalLossTest, self).setUp()\n",
+    "    self.logits = np.array([[.1, .2, .3], [.4, .5, .6]])\n",
+    "    self.labels = np.array([[0.], [1.5]])\n",
+    "\n",
+    "  def zero_inflated_lognormal(self, labels, logits):\n",
+    "    positive_logits = logits[..., :1]\n",
+    "    loss_zero = _softplus(positive_logits)\n",
+    "    loc = logits[..., 1:2]\n",
+    "    scale = np.maximum(\n",
+    "        _softplus(logits[..., 2:]),\n",
+    "        np.sqrt(tf.keras.backend.epsilon()))\n",
+    "    log_prob_non_zero = stats.lognorm.logpdf(\n",
+    "        x=labels, s=scale, loc=0, scale=np.exp(loc))\n",
+    "    loss_non_zero = _softplus(-positive_logits) - log_prob_non_zero\n",
+    "    return np.mean(np.where(labels == 0., loss_zero, loss_non_zero), axis=-1)\n",
+    "\n",
+    "  def test_loss_value(self):\n",
+    "    expected_loss = self.zero_inflated_lognormal(self.labels, self.logits)\n",
+    "    # call the loss defined two cells above; the original test module\n",
+    "    # imports it as ``zero_inflated_lognormal.zero_inflated_lognormal_loss``\n",
+    "    loss = zero_inflated_lognormal_loss(\n",
+    "        self.labels, self.logits)\n",
+    "    self.assertArrayNear(self.evaluate(loss), expected_loss, _ERR_TOL)"
+   ]
+  }
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "3b99005fd577fa40f3cce433b2b92303885900e634b2b5344c07c59d06c8792d"
+  },
+  "kernelspec": {
+   "display_name": "Python 3.8.5 64-bit ('base': conda)",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.5"
+  },
+  "toc": {
+   "base_numbering": 1,
+   "nav_menu": {},
+   "number_sections": true,
+   "sideBar": true,
+   "skip_h1_title": false,
+   "title_cell": "Table of Contents",
+   "title_sidebar": "Contents",
+   "toc_cell": false,
+   "toc_position": {},
+   "toc_section_display": true,
+   "toc_window_display": false
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/pytorch_widedeep/losses.py b/pytorch_widedeep/losses.py
index 84a8942..e33000f 100644
--- a/pytorch_widedeep/losses.py
+++ b/pytorch_widedeep/losses.py
@@ -7,6 +7,57 @@ from pytorch_widedeep.wdtypes import *  # noqa: F403
 
 use_cuda = torch.cuda.is_available()
 
 
+class ZILNLoss(nn.Module):
+    r"""Implementation of the `Zero Inflated LogNormal loss
+    <https://github.com/google/lifetime_value/blob/master/lifetime_value/zero_inflated_lognormal.py>`_,
+    adapted from the original Keras implementation.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input: Tensor, target: Tensor) -> Tensor:
+        r"""
+        Parameters
+        ----------
+        input: Tensor
+            input tensor with predictions (not probabilities) of shape
+            (batch_size, 3)
+        target: Tensor
+            target tensor with the actual values
+
+        Examples
+        --------
+        >>> import torch
+        >>>
+        >>> from pytorch_widedeep.losses import ZILNLoss
+        >>>
+        >>> # REGRESSION
+        >>> target = torch.tensor([[0., 1.5]]).view(-1, 1)
+        >>> input = torch.tensor([[.1, .2, .3], [.4, .5, .6]])
+        >>> ZILNLoss()(input, target)
+        tensor([0.7444, 1.8784])
+        """
+        positive = target > 0
+        positive = positive.float()
+
+        assert input.shape == torch.Size([target.shape[0], 3]), "Wrong shape of input."
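+        # head layout, as in the Keras original: column 0 is the logit for
+        # P(target > 0), column 1 the lognormal loc, column 2 the raw
+        # (pre-softplus) scale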
+        positive_input = input[..., :1]
+
+        # per-example classification loss (mean over the last dim, as in the
+        # Keras original) so that it adds to the per-example regression loss
+        classification_loss = F.binary_cross_entropy_with_logits(
+            positive_input, positive, reduction="none"
+        ).mean(dim=-1)
+
+        loc = input[..., 1:2]
+        # clamp the scale from below so the lognormal stays well defined
+        scale = torch.maximum(
+            F.softplus(input[..., 2:]),
+            torch.sqrt(torch.tensor(torch.finfo(torch.float32).eps, device=input.device)))
+        # replace zero targets with ones so that log_prob is finite; these
+        # terms are masked out by ``positive`` anyway
+        safe_labels = positive * target + (
+            1 - positive) * torch.ones_like(target)
+
+        regression_loss = -torch.mean(
+            positive * torch.distributions.log_normal.LogNormal(loc=loc, scale=scale).log_prob(safe_labels),
+            dim=-1)
+
+        return classification_loss + regression_loss
+
+
 class FocalLoss(nn.Module):
     r"""Implementation of the `focal loss <https://arxiv.org/pdf/1708.02002.pdf>`_ for both binary and
-- 
GitLab