{ "cells": [ { "cell_type": "markdown", "id": "6658e02e", "metadata": {}, "source": [ "# Losses" ] }, { "cell_type": "markdown", "id": "639ad78b", "metadata": {}, "source": [ "## Initial imports" ] }, { "cell_type": "code", "execution_count": 71, "id": "d40952c1", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import torch\n", "\n", "\n", "# increase displayed columns in jupyter notebook\n", "pd.set_option(\"display.max_columns\", 200)\n", "pd.set_option(\"display.max_rows\", 300)" ] }, { "cell_type": "markdown", "id": "54d7d143", "metadata": {}, "source": [ "# Pytorch implementation\n", "* reduction=\"none\" in binary_cross_entropy_with_logits gives same result for given testing usecase as authors implementation https://pytorch.org/docs/stable/generated/torch.nn.functional.binary_cross_entropy_with_logits.html\n", " * hovewer keras BCE uses auto reduction by default https://www.tensorflow.org/api_docs/python/tf/keras/losses/BinaryCrossentropy which changes according to usecase : so wtf? https://www.tensorflow.org/api_docs/python/tf/keras/losses/Reduction\n", "* keras also return flattened BCE result\n", "* keras needs loss per sample and it averages it in the backend: https://keras.io/api/losses/" ] }, { "cell_type": "code", "execution_count": 83, "id": "9864678c", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from pytorch_widedeep.wdtypes import *\n", "\n", "\n", "def _predict_ziln(preds: Tensor) -> Tensor:\n", " \"\"\"Calculates predicted mean of zero inflated lognormal logits.\n", " Adjusted implementaion of `code\n", " `\n", " Arguments:\n", " preds: [batch_size, 3] tensor of logits.\n", " Returns:\n", " ziln_preds: [batch_size, 1] tensor of predicted mean.\n", " \"\"\"\n", " positive_probs = torch.sigmoid(preds[..., :1])\n", " loc = preds[..., 1:2]\n", " scale = F.softplus(preds[..., 2:])\n", " ziln_preds = (\n", " positive_probs *\n", " torch.exp(loc + 0.5 * torch.square(scale)))\n", " return ziln_preds\n", "\n", "\n", "class ZILNLoss(nn.Module):\n", " r\"\"\"Adjusted implementation of the `Zero Inflated LogNormal loss\n", " ` and its `code\n", " `\n", " \"\"\"\n", "\n", " def __init__(self):\n", " super().__init__()\n", "\n", " def forward(self, input: Tensor, target: Tensor) -> Tensor:\n", " r\"\"\"\n", " Parameters\n", " ----------\n", " input: Tensor\n", " input tensor with predictions (not probabilities)\n", " target: Tensor\n", " target tensor with the actual classes\n", "\n", " Examples\n", " --------\n", " >>> import torch\n", " >>>\n", " >>> from pytorch_widedeep.losses import ZILNLoss\n", " >>>\n", " >>> # REGRESSION\n", " >>> target = torch.tensor([[0., 1.5]]).view(-1, 1)\n", " >>> input = torch.tensor([[.1, .2, .3], [.4, .5, .6]])\n", " >>> ZILNLoss()(input, target)\n", " tensor([0.6287, 1.9941])\n", " \"\"\"\n", " positive = target>0\n", " positive = positive.float()\n", "\n", " assert input.shape == torch.Size([target.shape[0], 3]), \"Wrong shape of input.\"\n", " positive_input = input[..., :1]\n", "\n", " classification_loss = F.binary_cross_entropy_with_logits(positive_input, positive, reduction=\"none\").flatten() \n", "\n", " loc = input[..., 1:2]\n", " scale = torch.maximum(\n", " F.softplus(input[..., 2:]),\n", " torch.sqrt(torch.Tensor([torch.finfo(torch.float32).eps])))\n", " safe_labels = positive * target + (\n", " 1 - positive) * torch.ones_like(target)\n", "\n", " regression_loss = -torch.mean(\n", " positive * torch.distributions.log_normal.LogNormal(loc=loc, scale=scale).log_prob(safe_labels),\n", " dim=-1)\n", "\n", " return torch.mean(classification_loss + regression_loss)" ] }, { "cell_type": "code", "execution_count": 84, "id": "d1073cc6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(1.3114)" ] }, "execution_count": 84, "metadata": {}, "output_type": "execute_result" } ], "source": [ "use_cuda = torch.cuda.is_available()\n", "target = torch.tensor([[0., 1.5]]).view(-1, 1)\n", "input = torch.tensor([[.1, .2, .3], [.4, .5, .6]])\n", "ZILNLoss()(input, target)" ] }, { "cell_type": "markdown", "id": "6f91526c", "metadata": {}, "source": [ "# Keras implementation - original" ] }, { "cell_type": "markdown", "id": "9fc84a5f", "metadata": {}, "source": [ "* https://github.com/google/lifetime_value/blob/master/lifetime_value/zero_inflated_lognormal.py" ] }, { "cell_type": "code", "execution_count": 67, "id": "9ac7bf00", "metadata": {}, "outputs": [], "source": [ "import tensorflow.compat.v1 as tf\n", "import tensorflow_probability as tfp\n", "tfd = tfp.distributions\n", "\n", "\n", "def zero_inflated_lognormal_pred(logits: tf.Tensor) -> tf.Tensor:\n", " \"\"\"Calculates predicted mean of zero inflated lognormal logits.\n", " Arguments:\n", " logits: [batch_size, 3] tensor of logits.\n", " Returns:\n", " preds: [batch_size, 1] tensor of predicted mean.\n", " \"\"\"\n", " logits = tf.convert_to_tensor(logits, dtype=tf.float32)\n", " positive_probs = tf.keras.backend.sigmoid(logits[..., :1])\n", " loc = logits[..., 1:2]\n", " scale = tf.keras.backend.softplus(logits[..., 2:])\n", " preds = (\n", " positive_probs *\n", " tf.keras.backend.exp(loc + 0.5 * tf.keras.backend.square(scale)))\n", " return preds\n", "\n", "\n", "def zero_inflated_lognormal_loss(labels: tf.Tensor,\n", " logits: tf.Tensor) -> tf.Tensor:\n", " \"\"\"Computes the zero inflated lognormal loss.\n", " Usage with tf.keras API:\n", " ```python\n", " model = tf.keras.Model(inputs, outputs)\n", " model.compile('sgd', loss=zero_inflated_lognormal)\n", " ```\n", " Arguments:\n", " labels: True targets, tensor of shape [batch_size, 1].\n", " logits: Logits of output layer, tensor of shape [batch_size, 3].\n", " Returns:\n", " Zero inflated lognormal loss value.\n", " \"\"\"\n", " labels = tf.convert_to_tensor(labels, dtype=tf.float32)\n", " positive = tf.cast(labels > 0, tf.float32)\n", "\n", " logits = tf.convert_to_tensor(logits, dtype=tf.float32)\n", " logits.shape.assert_is_compatible_with(\n", " tf.TensorShape(labels.shape[:-1].as_list() + [3]))\n", "\n", " positive_logits = logits[..., :1]\n", " classification_loss = tf.keras.losses.binary_crossentropy(\n", " y_true=positive, y_pred=positive_logits, from_logits=True)\n", "\n", " loc = logits[..., 1:2]\n", " scale = tf.math.maximum(\n", " tf.keras.backend.softplus(logits[..., 2:]),\n", " tf.math.sqrt(tf.keras.backend.epsilon()))\n", " safe_labels = positive * labels + (\n", " 1 - positive) * tf.keras.backend.ones_like(labels)\n", " regression_loss = -tf.keras.backend.mean(\n", " positive * tfd.LogNormal(loc=loc, scale=scale).log_prob(safe_labels),\n", " axis=-1)\n", "\n", " return classification_loss + regression_loss" ] }, { "cell_type": "code", "execution_count": 68, "id": "25668bf0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "target = tf.reshape(tf.constant([[0., 1.5]]), [-1, 1])\n", "input = tf.constant([[.1, .2, .3], [.4, .5, .6]])\n", "zero_inflated_lognormal_loss(target, input)" ] }, { "cell_type": "code", "execution_count": 69, "id": "374f5d40", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "source": [ "zero_inflated_lognormal_pred(input)" ] } ], "metadata": { "interpreter": { "hash": "3b99005fd577fa40f3cce433b2b92303885900e634b2b5344c07c59d06c8792d" }, "kernelspec": { "display_name": "Python 3.8.5 64-bit ('base': conda)", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }