{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import os\n", "import tempfile\n", "\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "\n", "import numpy as np\n", "import tensorflow as tf\n", "import tensorflow_datasets as tfds\n", "\n", "import tensorflow_recommenders as tfrs\n", "\n", "plt.style.use('seaborn-whitegrid')" ] }, { "cell_type": "code", "execution_count": 2, "outputs": [], "source": [ "ratings = tfds.load(\"movielens/100k-ratings\", split=\"train\")\n", "movies = tfds.load(\"movielens/100k-movies\", split=\"train\")\n", "\n", "ratings = ratings.map(lambda x: {\n", " \"movie_title\": x[\"movie_title\"],\n", " \"user_id\": x[\"user_id\"],\n", " \"timestamp\": x[\"timestamp\"],\n", "})\n", "movies = movies.map(lambda x: x[\"movie_title\"])" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 3, "outputs": [], "source": [ "timestamps = np.concatenate(list(ratings.map(lambda x: x[\"timestamp\"]).batch(100)))\n", "\n", "max_timestamp = timestamps.max()\n", "min_timestamp = timestamps.min()\n", "\n", "timestamp_buckets = np.linspace(\n", " min_timestamp, max_timestamp, num=1000,\n", ")\n", "\n", "unique_movie_titles = np.unique(np.concatenate(list(movies.batch(1000))))\n", "unique_user_ids = np.unique(np.concatenate(list(ratings.batch(1_000).map(\n", " lambda x: x[\"user_id\"]))))" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 4, "outputs": [], "source": [ "class UserModel(tf.keras.Model):\n", "\n", " def __init__(self):\n", " super().__init__()\n", "\n", " self.user_embedding = tf.keras.Sequential([\n", " tf.keras.layers.experimental.preprocessing.StringLookup(\n", " vocabulary=unique_user_ids, mask_token=None),\n", " tf.keras.layers.Embedding(len(unique_user_ids) + 1, 32),\n", " ])\n", " self.timestamp_embedding = 
tf.keras.Sequential([\n", " tf.keras.layers.experimental.preprocessing.Discretization(timestamp_buckets.tolist()),\n", " tf.keras.layers.Embedding(len(timestamp_buckets) + 1, 32),\n", " ])\n", " self.normalized_timestamp = tf.keras.layers.experimental.preprocessing.Normalization()\n", "\n", " self.normalized_timestamp.adapt(timestamps)\n", "\n", " def call(self, inputs):\n", " # Take the input dictionary, pass it through each input layer,\n", " # and concatenate the result.\n", " return tf.concat([\n", " self.user_embedding(inputs[\"user_id\"]),\n", " self.timestamp_embedding(inputs[\"timestamp\"]),\n", " self.normalized_timestamp(inputs[\"timestamp\"]),\n", " ], axis=1)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 5, "outputs": [], "source": [ "class QueryModel(tf.keras.Model):\n", " \"\"\"Model for encoding user queries.\"\"\"\n", "\n", " def __init__(self, layer_sizes):\n", " \"\"\"Model for encoding user queries.\n", "\n", " Args:\n", " layer_sizes:\n", " A list of integers where the i-th entry represents the number of units\n", " the i-th layer contains.\n", " \"\"\"\n", " super().__init__()\n", "\n", " # We first use the user model for generating embeddings.\n", " self.embedding_model = UserModel()\n", "\n", " # Then construct the layers.\n", " self.dense_layers = tf.keras.Sequential()\n", "\n", " # Use the ReLU activation for all but the last layer.\n", " for layer_size in layer_sizes[:-1]:\n", " self.dense_layers.add(tf.keras.layers.Dense(layer_size, activation=\"relu\"))\n", "\n", " # No activation for the last layer.\n", " for layer_size in layer_sizes[-1:]:\n", " self.dense_layers.add(tf.keras.layers.Dense(layer_size))\n", "\n", " def call(self, inputs):\n", " feature_embedding = self.embedding_model(inputs)\n", " return self.dense_layers(feature_embedding)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 6, "outputs": 
[], "source": [ "class MovieModel(tf.keras.Model):\n", "\n", " def __init__(self):\n", " super().__init__()\n", "\n", " max_tokens = 10_000\n", "\n", " self.title_embedding = tf.keras.Sequential([\n", " tf.keras.layers.experimental.preprocessing.StringLookup(\n", " vocabulary=unique_movie_titles,mask_token=None),\n", " tf.keras.layers.Embedding(len(unique_movie_titles) + 1, 32)\n", " ])\n", "\n", " self.title_vectorizer = tf.keras.layers.experimental.preprocessing.TextVectorization(\n", " max_tokens=max_tokens)\n", "\n", " self.title_text_embedding = tf.keras.Sequential([\n", " self.title_vectorizer,\n", " tf.keras.layers.Embedding(max_tokens, 32, mask_zero=True),\n", " tf.keras.layers.GlobalAveragePooling1D(),\n", " ])\n", "\n", " self.title_vectorizer.adapt(movies)\n", "\n", " def call(self, titles):\n", " return tf.concat([\n", " self.title_embedding(titles),\n", " self.title_text_embedding(titles),\n", " ], axis=1)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 7, "outputs": [], "source": [ "class CandidateModel(tf.keras.Model):\n", " \"\"\"Model for encoding movies.\"\"\"\n", "\n", " def __init__(self, layer_sizes):\n", " \"\"\"Model for encoding movies.\n", "\n", " Args:\n", " layer_sizes:\n", " A list of integers where the i-th entry represents the number of units\n", " the i-th layer contains.\n", " \"\"\"\n", " super().__init__()\n", "\n", " self.embedding_model = MovieModel()\n", "\n", " # Then construct the layers.\n", " self.dense_layers = tf.keras.Sequential()\n", "\n", " # Use the ReLU activation for all but the last layer.\n", " for layer_size in layer_sizes[:-1]:\n", " self.dense_layers.add(tf.keras.layers.Dense(layer_size, activation=\"relu\"))\n", "\n", " # No activation for the last layer.\n", " for layer_size in layer_sizes[-1:]:\n", " self.dense_layers.add(tf.keras.layers.Dense(layer_size))\n", "\n", " def call(self, inputs):\n", " feature_embedding = 
self.embedding_model(inputs)\n", " return self.dense_layers(feature_embedding)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 8, "outputs": [], "source": [ "class MovielensModel(tfrs.models.Model):\n", "\n", " def __init__(self, layer_sizes):\n", " super().__init__()\n", " self.query_model = QueryModel(layer_sizes)\n", " self.candidate_model = CandidateModel(layer_sizes)\n", " self.task = tfrs.tasks.Retrieval(\n", " metrics=tfrs.metrics.FactorizedTopK(\n", " candidates=movies.batch(128).map(self.candidate_model),\n", " ),\n", " )\n", "\n", " def compute_loss(self, features, training=False):\n", " # We only pass the user id and timestamp features into the query model. This\n", " # is to ensure that the training inputs would have the same keys as the\n", " # query inputs. Otherwise the discrepancy in input structure would cause an\n", " # error when loading the query model after saving it.\n", " query_embeddings = self.query_model({\n", " \"user_id\": features[\"user_id\"],\n", " \"timestamp\": features[\"timestamp\"],\n", " })\n", " movie_embeddings = self.candidate_model(features[\"movie_title\"])\n", "\n", " return self.task(\n", " query_embeddings, movie_embeddings)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 9, "outputs": [], "source": [ "tf.random.set_seed(42)\n", "shuffled = ratings.shuffle(100_000, seed=42, reshuffle_each_iteration=False)\n", "\n", "train = shuffled.take(80_000)\n", "test = shuffled.skip(80_000).take(20_000)\n", "\n", "cached_train = train.shuffle(100_000).batch(2048)\n", "cached_test = test.batch(4096).cache()" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 10, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/20\n", "WARNING:tensorflow:From 
/home/mao/anaconda3/envs/tf2.5/lib/python3.6/site-packages/tensorflow/python/ops/parallel_for/pfor.py:2382: calling gather (from tensorflow.python.ops.array_ops) with validate_indices is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /home/mao/anaconda3/envs/tf2.5/lib/python3.6/site-packages/tensorflow/python/ops/parallel_for/pfor.py:2382: calling gather (from tensorflow.python.ops.array_ops) with validate_indices is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using a while_loop for converting BoostedTreesBucketize\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Using a while_loop for converting BoostedTreesBucketize\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using a while_loop for converting BoostedTreesBucketize\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Using a while_loop for converting BoostedTreesBucketize\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using a while_loop for converting BoostedTreesBucketize\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Using a while_loop for converting BoostedTreesBucketize\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "40/40 - 11s - factorized_top_k/top_1_categorical_accuracy: 0.0089 - factorized_top_k/top_5_categorical_accuracy: 0.0206 - factorized_top_k/top_10_categorical_accuracy: 0.0312 - factorized_top_k/top_50_categorical_accuracy: 0.0935 - 
factorized_top_k/top_100_categorical_accuracy: 0.1611 - loss: 563.2299 - regularization_loss: 0.0000e+00 - total_loss: 563.2299\n", "Epoch 2/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0035 - factorized_top_k/top_5_categorical_accuracy: 0.0175 - factorized_top_k/top_10_categorical_accuracy: 0.0333 - factorized_top_k/top_50_categorical_accuracy: 0.1392 - factorized_top_k/top_100_categorical_accuracy: 0.2560 - loss: 556.3237 - regularization_loss: 0.0000e+00 - total_loss: 556.3237\n", "Epoch 3/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0025 - factorized_top_k/top_5_categorical_accuracy: 0.0195 - factorized_top_k/top_10_categorical_accuracy: 0.0391 - factorized_top_k/top_50_categorical_accuracy: 0.1716 - factorized_top_k/top_100_categorical_accuracy: 0.3043 - loss: 554.7384 - regularization_loss: 0.0000e+00 - total_loss: 554.7384\n", "Epoch 4/20\n", "40/40 - 9s - factorized_top_k/top_1_categorical_accuracy: 0.0014 - factorized_top_k/top_5_categorical_accuracy: 0.0216 - factorized_top_k/top_10_categorical_accuracy: 0.0457 - factorized_top_k/top_50_categorical_accuracy: 0.1957 - factorized_top_k/top_100_categorical_accuracy: 0.3319 - loss: 529.2206 - regularization_loss: 0.0000e+00 - total_loss: 529.2206\n", "Epoch 5/20\n", "40/40 - 9s - factorized_top_k/top_1_categorical_accuracy: 0.0017 - factorized_top_k/top_5_categorical_accuracy: 0.0255 - factorized_top_k/top_10_categorical_accuracy: 0.0516 - factorized_top_k/top_50_categorical_accuracy: 0.2121 - factorized_top_k/top_100_categorical_accuracy: 0.3540 - loss: 528.1415 - regularization_loss: 0.0000e+00 - total_loss: 528.1415\n", "Epoch 6/20\n", "40/40 - 9s - factorized_top_k/top_1_categorical_accuracy: 0.0015 - factorized_top_k/top_5_categorical_accuracy: 0.0275 - factorized_top_k/top_10_categorical_accuracy: 0.0564 - factorized_top_k/top_50_categorical_accuracy: 0.2261 - factorized_top_k/top_100_categorical_accuracy: 0.3714 - loss: 519.7426 - regularization_loss: 
0.0000e+00 - total_loss: 519.7426\n", "Epoch 7/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0014 - factorized_top_k/top_5_categorical_accuracy: 0.0311 - factorized_top_k/top_10_categorical_accuracy: 0.0611 - factorized_top_k/top_50_categorical_accuracy: 0.2389 - factorized_top_k/top_100_categorical_accuracy: 0.3863 - loss: 514.4518 - regularization_loss: 0.0000e+00 - total_loss: 514.4518\n", "Epoch 8/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0012 - factorized_top_k/top_5_categorical_accuracy: 0.0334 - factorized_top_k/top_10_categorical_accuracy: 0.0659 - factorized_top_k/top_50_categorical_accuracy: 0.2502 - factorized_top_k/top_100_categorical_accuracy: 0.4024 - loss: 508.0785 - regularization_loss: 0.0000e+00 - total_loss: 508.0785\n", "Epoch 9/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0012 - factorized_top_k/top_5_categorical_accuracy: 0.0345 - factorized_top_k/top_10_categorical_accuracy: 0.0684 - factorized_top_k/top_50_categorical_accuracy: 0.2619 - factorized_top_k/top_100_categorical_accuracy: 0.4155 - loss: 508.0655 - regularization_loss: 0.0000e+00 - total_loss: 508.0655\n", "Epoch 10/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0016 - factorized_top_k/top_5_categorical_accuracy: 0.0373 - factorized_top_k/top_10_categorical_accuracy: 0.0725 - factorized_top_k/top_50_categorical_accuracy: 0.2712 - factorized_top_k/top_100_categorical_accuracy: 0.4290 - loss: 492.9934 - regularization_loss: 0.0000e+00 - total_loss: 492.9934\n", "Epoch 11/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0018 - factorized_top_k/top_5_categorical_accuracy: 0.0395 - factorized_top_k/top_10_categorical_accuracy: 0.0769 - factorized_top_k/top_50_categorical_accuracy: 0.2790 - factorized_top_k/top_100_categorical_accuracy: 0.4393 - loss: 508.4373 - regularization_loss: 0.0000e+00 - total_loss: 508.4373\n", "Epoch 12/20\n", "40/40 - 10s - 
factorized_top_k/top_1_categorical_accuracy: 0.0014 - factorized_top_k/top_5_categorical_accuracy: 0.0405 - factorized_top_k/top_10_categorical_accuracy: 0.0790 - factorized_top_k/top_50_categorical_accuracy: 0.2868 - factorized_top_k/top_100_categorical_accuracy: 0.4497 - loss: 509.4974 - regularization_loss: 0.0000e+00 - total_loss: 509.4974\n", "Epoch 13/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0017 - factorized_top_k/top_5_categorical_accuracy: 0.0426 - factorized_top_k/top_10_categorical_accuracy: 0.0816 - factorized_top_k/top_50_categorical_accuracy: 0.2955 - factorized_top_k/top_100_categorical_accuracy: 0.4575 - loss: 491.1711 - regularization_loss: 0.0000e+00 - total_loss: 491.1711\n", "Epoch 14/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0017 - factorized_top_k/top_5_categorical_accuracy: 0.0440 - factorized_top_k/top_10_categorical_accuracy: 0.0846 - factorized_top_k/top_50_categorical_accuracy: 0.3013 - factorized_top_k/top_100_categorical_accuracy: 0.4663 - loss: 487.5421 - regularization_loss: 0.0000e+00 - total_loss: 487.5421\n", "Epoch 15/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0018 - factorized_top_k/top_5_categorical_accuracy: 0.0453 - factorized_top_k/top_10_categorical_accuracy: 0.0868 - factorized_top_k/top_50_categorical_accuracy: 0.3078 - factorized_top_k/top_100_categorical_accuracy: 0.4727 - loss: 479.2850 - regularization_loss: 0.0000e+00 - total_loss: 479.2850\n", "Epoch 16/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0018 - factorized_top_k/top_5_categorical_accuracy: 0.0461 - factorized_top_k/top_10_categorical_accuracy: 0.0896 - factorized_top_k/top_50_categorical_accuracy: 0.3146 - factorized_top_k/top_100_categorical_accuracy: 0.4787 - loss: 452.3918 - regularization_loss: 0.0000e+00 - total_loss: 452.3918\n", "Epoch 17/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0020 - 
num_epochs = 20

# Shallow variant: a single 32-unit projection layer in each tower.
model = MovielensModel([32])
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))

one_layer_history = model.fit(
    cached_train,
    epochs=num_epochs,
    verbose=2)
"stream", "text": [ "WARNING:tensorflow:Using a while_loop for converting BoostedTreesBucketize\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "5/5 [==============================] - 3s 412ms/step - factorized_top_k/top_1_categorical_accuracy: 4.0000e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0037 - factorized_top_k/top_10_categorical_accuracy: 0.0100 - factorized_top_k/top_50_categorical_accuracy: 0.1098 - factorized_top_k/top_100_categorical_accuracy: 0.2492 - loss: 32081.5990 - regularization_loss: 0.0000e+00 - total_loss: 32081.5990\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b
\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b
# Evaluate the one-layer model on the held-out test split; returning a dict
# makes the metric names explicit in the displayed output.
model.evaluate(cached_test, return_dict=True)
for converting BoostedTreesBucketize\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Using a while_loop for converting BoostedTreesBucketize\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "40/40 - 11s - factorized_top_k/top_1_categorical_accuracy: 0.0204 - factorized_top_k/top_5_categorical_accuracy: 0.0302 - factorized_top_k/top_10_categorical_accuracy: 0.0393 - factorized_top_k/top_50_categorical_accuracy: 0.0873 - factorized_top_k/top_100_categorical_accuracy: 0.1328 - loss: 597.2026 - regularization_loss: 0.0000e+00 - total_loss: 597.2026\n", "Epoch 2/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0061 - factorized_top_k/top_5_categorical_accuracy: 0.0166 - factorized_top_k/top_10_categorical_accuracy: 0.0255 - factorized_top_k/top_50_categorical_accuracy: 0.0856 - factorized_top_k/top_100_categorical_accuracy: 0.1573 - loss: 584.0324 - regularization_loss: 0.0000e+00 - total_loss: 584.0324\n", "Epoch 3/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0026 - factorized_top_k/top_5_categorical_accuracy: 0.0097 - factorized_top_k/top_10_categorical_accuracy: 0.0181 - factorized_top_k/top_50_categorical_accuracy: 0.0809 - factorized_top_k/top_100_categorical_accuracy: 0.1549 - loss: 582.6514 - regularization_loss: 0.0000e+00 - total_loss: 582.6514\n", "Epoch 4/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0061 - factorized_top_k/top_5_categorical_accuracy: 0.0176 - factorized_top_k/top_10_categorical_accuracy: 0.0285 - factorized_top_k/top_50_categorical_accuracy: 0.0981 - factorized_top_k/top_100_categorical_accuracy: 0.1757 - loss: 579.3187 - regularization_loss: 0.0000e+00 - total_loss: 579.3187\n", "Epoch 5/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0036 - factorized_top_k/top_5_categorical_accuracy: 0.0131 - factorized_top_k/top_10_categorical_accuracy: 0.0235 - factorized_top_k/top_50_categorical_accuracy: 0.0969 - 
factorized_top_k/top_100_categorical_accuracy: 0.1864 - loss: 554.9879 - regularization_loss: 0.0000e+00 - total_loss: 554.9879\n", "Epoch 6/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0032 - factorized_top_k/top_5_categorical_accuracy: 0.0127 - factorized_top_k/top_10_categorical_accuracy: 0.0242 - factorized_top_k/top_50_categorical_accuracy: 0.1050 - factorized_top_k/top_100_categorical_accuracy: 0.2002 - loss: 553.6580 - regularization_loss: 0.0000e+00 - total_loss: 553.6580\n", "Epoch 7/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0044 - factorized_top_k/top_5_categorical_accuracy: 0.0152 - factorized_top_k/top_10_categorical_accuracy: 0.0279 - factorized_top_k/top_50_categorical_accuracy: 0.1161 - factorized_top_k/top_100_categorical_accuracy: 0.2174 - loss: 548.2044 - regularization_loss: 0.0000e+00 - total_loss: 548.2044\n", "Epoch 8/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0025 - factorized_top_k/top_5_categorical_accuracy: 0.0135 - factorized_top_k/top_10_categorical_accuracy: 0.0265 - factorized_top_k/top_50_categorical_accuracy: 0.1223 - factorized_top_k/top_100_categorical_accuracy: 0.2304 - loss: 556.7844 - regularization_loss: 0.0000e+00 - total_loss: 556.7844\n", "Epoch 9/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0020 - factorized_top_k/top_5_categorical_accuracy: 0.0135 - factorized_top_k/top_10_categorical_accuracy: 0.0273 - factorized_top_k/top_50_categorical_accuracy: 0.1323 - factorized_top_k/top_100_categorical_accuracy: 0.2429 - loss: 544.5278 - regularization_loss: 0.0000e+00 - total_loss: 544.5278\n", "Epoch 10/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0025 - factorized_top_k/top_5_categorical_accuracy: 0.0156 - factorized_top_k/top_10_categorical_accuracy: 0.0313 - factorized_top_k/top_50_categorical_accuracy: 0.1419 - factorized_top_k/top_100_categorical_accuracy: 0.2569 - loss: 566.6818 - regularization_loss: 
0.0000e+00 - total_loss: 566.6818\n", "Epoch 11/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0027 - factorized_top_k/top_5_categorical_accuracy: 0.0161 - factorized_top_k/top_10_categorical_accuracy: 0.0321 - factorized_top_k/top_50_categorical_accuracy: 0.1454 - factorized_top_k/top_100_categorical_accuracy: 0.2649 - loss: 536.5007 - regularization_loss: 0.0000e+00 - total_loss: 536.5007\n", "Epoch 12/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0025 - factorized_top_k/top_5_categorical_accuracy: 0.0176 - factorized_top_k/top_10_categorical_accuracy: 0.0351 - factorized_top_k/top_50_categorical_accuracy: 0.1565 - factorized_top_k/top_100_categorical_accuracy: 0.2797 - loss: 530.5717 - regularization_loss: 0.0000e+00 - total_loss: 530.5717\n", "Epoch 13/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0022 - factorized_top_k/top_5_categorical_accuracy: 0.0170 - factorized_top_k/top_10_categorical_accuracy: 0.0347 - factorized_top_k/top_50_categorical_accuracy: 0.1620 - factorized_top_k/top_100_categorical_accuracy: 0.2865 - loss: 535.3151 - regularization_loss: 0.0000e+00 - total_loss: 535.3151\n", "Epoch 14/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0029 - factorized_top_k/top_5_categorical_accuracy: 0.0195 - factorized_top_k/top_10_categorical_accuracy: 0.0383 - factorized_top_k/top_50_categorical_accuracy: 0.1677 - factorized_top_k/top_100_categorical_accuracy: 0.2930 - loss: 532.3218 - regularization_loss: 0.0000e+00 - total_loss: 532.3218\n", "Epoch 15/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0031 - factorized_top_k/top_5_categorical_accuracy: 0.0198 - factorized_top_k/top_10_categorical_accuracy: 0.0397 - factorized_top_k/top_50_categorical_accuracy: 0.1722 - factorized_top_k/top_100_categorical_accuracy: 0.3015 - loss: 546.2095 - regularization_loss: 0.0000e+00 - total_loss: 546.2095\n", "Epoch 16/20\n", "40/40 - 10s - 
# Deeper variant: 64-unit hidden layer followed by a 32-d output embedding.
model = MovielensModel([64, 32])
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))

two_layer_history = model.fit(
    cached_train,
    epochs=num_epochs,
    verbose=2)
b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\
b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\n" ] }, { "data": { "text/plain": "{'factorized_top_k/top_1_categorical_accuracy': 0.0003499999875202775,\n 'factorized_top_k/top_5_categorical_accuracy': 0.006800000090152025,\n 'factorized_top_k/top_10_categorical_accuracy': 0.017249999567866325,\n 'factorized_top_k/top_50_categorical_accuracy': 0.11909999698400497,\n 'factorized_top_k/top_100_categorical_accuracy': 0.2454500049352646,\n 'loss': 28001.8125,\n 'regularization_loss': 0,\n 'total_loss': 28001.8125}" }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.evaluate(cached_test, return_dict=True)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 14, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/20\n", "WARNING:tensorflow:Using a while_loop for converting BoostedTreesBucketize\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Using a while_loop for converting 
BoostedTreesBucketize\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using a while_loop for converting BoostedTreesBucketize\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Using a while_loop for converting BoostedTreesBucketize\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using a while_loop for converting BoostedTreesBucketize\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Using a while_loop for converting BoostedTreesBucketize\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "40/40 - 11s - factorized_top_k/top_1_categorical_accuracy: 0.1176 - factorized_top_k/top_5_categorical_accuracy: 0.1213 - factorized_top_k/top_10_categorical_accuracy: 0.1262 - factorized_top_k/top_50_categorical_accuracy: 0.1358 - factorized_top_k/top_100_categorical_accuracy: 0.1475 - loss: 618.6954 - regularization_loss: 0.0000e+00 - total_loss: 618.6954\n", "Epoch 2/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0082 - factorized_top_k/top_5_categorical_accuracy: 0.0136 - factorized_top_k/top_10_categorical_accuracy: 0.0188 - factorized_top_k/top_50_categorical_accuracy: 0.0438 - factorized_top_k/top_100_categorical_accuracy: 0.0740 - loss: 603.6727 - regularization_loss: 0.0000e+00 - total_loss: 603.6727\n", "Epoch 3/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0076 - factorized_top_k/top_5_categorical_accuracy: 0.0189 - factorized_top_k/top_10_categorical_accuracy: 0.0281 - factorized_top_k/top_50_categorical_accuracy: 0.0685 - factorized_top_k/top_100_categorical_accuracy: 0.1094 - loss: 589.9850 - regularization_loss: 0.0000e+00 - total_loss: 589.9850\n", "Epoch 4/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0064 - factorized_top_k/top_5_categorical_accuracy: 0.0135 - factorized_top_k/top_10_categorical_accuracy: 0.0192 - 
factorized_top_k/top_50_categorical_accuracy: 0.0615 - factorized_top_k/top_100_categorical_accuracy: 0.1104 - loss: 573.2556 - regularization_loss: 0.0000e+00 - total_loss: 573.2556\n", "Epoch 5/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0019 - factorized_top_k/top_5_categorical_accuracy: 0.0077 - factorized_top_k/top_10_categorical_accuracy: 0.0135 - factorized_top_k/top_50_categorical_accuracy: 0.0611 - factorized_top_k/top_100_categorical_accuracy: 0.1208 - loss: 584.2108 - regularization_loss: 0.0000e+00 - total_loss: 584.2108\n", "Epoch 6/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0017 - factorized_top_k/top_5_categorical_accuracy: 0.0065 - factorized_top_k/top_10_categorical_accuracy: 0.0123 - factorized_top_k/top_50_categorical_accuracy: 0.0642 - factorized_top_k/top_100_categorical_accuracy: 0.1315 - loss: 555.7081 - regularization_loss: 0.0000e+00 - total_loss: 555.7081\n", "Epoch 7/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 8.3750e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0046 - factorized_top_k/top_10_categorical_accuracy: 0.0098 - factorized_top_k/top_50_categorical_accuracy: 0.0636 - factorized_top_k/top_100_categorical_accuracy: 0.1417 - loss: 574.5372 - regularization_loss: 0.0000e+00 - total_loss: 574.5372\n", "Epoch 8/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0052 - factorized_top_k/top_5_categorical_accuracy: 0.0123 - factorized_top_k/top_10_categorical_accuracy: 0.0203 - factorized_top_k/top_50_categorical_accuracy: 0.0842 - factorized_top_k/top_100_categorical_accuracy: 0.1606 - loss: 552.7963 - regularization_loss: 0.0000e+00 - total_loss: 552.7963\n", "Epoch 9/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0028 - factorized_top_k/top_5_categorical_accuracy: 0.0090 - factorized_top_k/top_10_categorical_accuracy: 0.0165 - factorized_top_k/top_50_categorical_accuracy: 0.0868 - 
factorized_top_k/top_100_categorical_accuracy: 0.1723 - loss: 558.8929 - regularization_loss: 0.0000e+00 - total_loss: 558.8929\n", "Epoch 10/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0034 - factorized_top_k/top_5_categorical_accuracy: 0.0101 - factorized_top_k/top_10_categorical_accuracy: 0.0190 - factorized_top_k/top_50_categorical_accuracy: 0.0914 - factorized_top_k/top_100_categorical_accuracy: 0.1850 - loss: 546.2024 - regularization_loss: 0.0000e+00 - total_loss: 546.2024\n", "Epoch 11/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0069 - factorized_top_k/top_5_categorical_accuracy: 0.0177 - factorized_top_k/top_10_categorical_accuracy: 0.0295 - factorized_top_k/top_50_categorical_accuracy: 0.1083 - factorized_top_k/top_100_categorical_accuracy: 0.1974 - loss: 564.8221 - regularization_loss: 0.0000e+00 - total_loss: 564.8221\n", "Epoch 12/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0064 - factorized_top_k/top_5_categorical_accuracy: 0.0184 - factorized_top_k/top_10_categorical_accuracy: 0.0300 - factorized_top_k/top_50_categorical_accuracy: 0.1110 - factorized_top_k/top_100_categorical_accuracy: 0.2063 - loss: 546.9897 - regularization_loss: 0.0000e+00 - total_loss: 546.9897\n", "Epoch 13/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0059 - factorized_top_k/top_5_categorical_accuracy: 0.0166 - factorized_top_k/top_10_categorical_accuracy: 0.0282 - factorized_top_k/top_50_categorical_accuracy: 0.1153 - factorized_top_k/top_100_categorical_accuracy: 0.2180 - loss: 574.3752 - regularization_loss: 0.0000e+00 - total_loss: 574.3752\n", "Epoch 14/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0051 - factorized_top_k/top_5_categorical_accuracy: 0.0163 - factorized_top_k/top_10_categorical_accuracy: 0.0291 - factorized_top_k/top_50_categorical_accuracy: 0.1191 - factorized_top_k/top_100_categorical_accuracy: 0.2230 - loss: 548.7125 - 
regularization_loss: 0.0000e+00 - total_loss: 548.7125\n", "Epoch 15/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0060 - factorized_top_k/top_5_categorical_accuracy: 0.0163 - factorized_top_k/top_10_categorical_accuracy: 0.0293 - factorized_top_k/top_50_categorical_accuracy: 0.1248 - factorized_top_k/top_100_categorical_accuracy: 0.2305 - loss: 556.8916 - regularization_loss: 0.0000e+00 - total_loss: 556.8916\n", "Epoch 16/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0057 - factorized_top_k/top_5_categorical_accuracy: 0.0166 - factorized_top_k/top_10_categorical_accuracy: 0.0289 - factorized_top_k/top_50_categorical_accuracy: 0.1266 - factorized_top_k/top_100_categorical_accuracy: 0.2365 - loss: 536.3606 - regularization_loss: 0.0000e+00 - total_loss: 536.3606\n", "Epoch 17/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0054 - factorized_top_k/top_5_categorical_accuracy: 0.0177 - factorized_top_k/top_10_categorical_accuracy: 0.0316 - factorized_top_k/top_50_categorical_accuracy: 0.1289 - factorized_top_k/top_100_categorical_accuracy: 0.2398 - loss: 555.4509 - regularization_loss: 0.0000e+00 - total_loss: 555.4509\n", "Epoch 18/20\n", "40/40 - 12s - factorized_top_k/top_1_categorical_accuracy: 0.0069 - factorized_top_k/top_5_categorical_accuracy: 0.0195 - factorized_top_k/top_10_categorical_accuracy: 0.0339 - factorized_top_k/top_50_categorical_accuracy: 0.1330 - factorized_top_k/top_100_categorical_accuracy: 0.2458 - loss: 556.9831 - regularization_loss: 0.0000e+00 - total_loss: 556.9831\n", "Epoch 19/20\n", "40/40 - 10s - factorized_top_k/top_1_categorical_accuracy: 0.0039 - factorized_top_k/top_5_categorical_accuracy: 0.0155 - factorized_top_k/top_10_categorical_accuracy: 0.0298 - factorized_top_k/top_50_categorical_accuracy: 0.1346 - factorized_top_k/top_100_categorical_accuracy: 0.2520 - loss: 546.2797 - regularization_loss: 0.0000e+00 - total_loss: 546.2797\n", "Epoch 20/20\n", "40/40 - 10s 
- factorized_top_k/top_1_categorical_accuracy: 0.0050 - factorized_top_k/top_5_categorical_accuracy: 0.0171 - factorized_top_k/top_10_categorical_accuracy: 0.0319 - factorized_top_k/top_50_categorical_accuracy: 0.1383 - factorized_top_k/top_100_categorical_accuracy: 0.2569 - loss: 556.3168 - regularization_loss: 0.0000e+00 - total_loss: 556.3168\n" ] } ], "source": [ "model = MovielensModel([128, 64, 32])\n", "model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))\n", "\n", "three_layer_history = model.fit(\n", " cached_train,\n", " epochs=num_epochs,\n", " verbose=2)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 15, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using a while_loop for converting BoostedTreesBucketize\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Using a while_loop for converting BoostedTreesBucketize\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "5/5 [==============================] - 2s 407ms/step - factorized_top_k/top_1_categorical_accuracy: 6.5000e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0068 - factorized_top_k/top_10_categorical_accuracy: 0.0170 - factorized_top_k/top_50_categorical_accuracy: 0.1090 - factorized_top_k/top_100_categorical_accuracy: 0.2250 - loss: 30809.3079 - regularization_loss: 0.0000e+00 - total_loss: 
30809.3079\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b
\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b
\b\b\b\b\b\b\b\b\b\n" ] }, { "data": { "text/plain": "{'factorized_top_k/top_1_categorical_accuracy': 0.0006500000017695129,\n 'factorized_top_k/top_5_categorical_accuracy': 0.006750000175088644,\n 'factorized_top_k/top_10_categorical_accuracy': 0.017000000923871994,\n 'factorized_top_k/top_50_categorical_accuracy': 0.10904999822378159,\n 'factorized_top_k/top_100_categorical_accuracy': 0.22495000064373016,\n 'loss': 28020.451171875,\n 'regularization_loss': 0,\n 'total_loss': 28020.451171875}" }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.evaluate(cached_test, return_dict=True)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 15, "outputs": [], "source": [], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }