In [1]:
import os
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

import tensorflow_recommenders as tfrs

In [2]:
ratings = tfds.load("movielens/100k-ratings", split="train")
movies = tfds.load("movielens/100k-movies", split="train")

ratings = ratings.map(lambda x: {
    "movie_title": x["movie_title"],
    "user_id": x["user_id"],
    "timestamp": x["timestamp"],
})
movies = movies.map(lambda x: x["movie_title"])

In [3]:
timestamps = np.concatenate(list(ratings.map(lambda x: x["timestamp"]).batch(100)))

max_timestamp = timestamps.max()
min_timestamp = timestamps.min()

timestamp_buckets = np.linspace(
    min_timestamp, max_timestamp, num=1000,
)

unique_movie_titles = np.unique(np.concatenate(list(movies.batch(1000))))
unique_user_ids = np.unique(np.concatenate(list(ratings.batch(1_000).map(
    lambda x: x["user_id"]))))

In [4]:
class UserModel(tf.keras.Model):

  def __init__(self, use_timestamps):
    super().__init__()

    self._use_timestamps = use_timestamps

    self.user_embedding = tf.keras.Sequential([
        tf.keras.layers.experimental.preprocessing.StringLookup(
            vocabulary=unique_user_ids, mask_token=None),
        tf.keras.layers.Embedding(len(unique_user_ids) + 1, 32),
    ])

    if use_timestamps:
      self.timestamp_embedding = tf.keras.Sequential([
          tf.keras.layers.experimental.preprocessing.Discretization(timestamp_buckets.tolist()),
          tf.keras.layers.Embedding(len(timestamp_buckets) + 1, 32),
      ])
      self.normalized_timestamp = tf.keras.layers.experimental.preprocessing.Normalization()

      self.normalized_timestamp.adapt(timestamps)

  def call(self, inputs):
    if not self._use_timestamps:
      return self.user_embedding(inputs["user_id"])

    return tf.concat([
        self.user_embedding(inputs["user_id"]),
        self.timestamp_embedding(inputs["timestamp"]),
        self.normalized_timestamp(inputs["timestamp"]),
    ], axis=1)

In [5]:
class MovieModel(tf.keras.Model):

  def __init__(self):
    super().__init__()

    max_tokens = 10_000

    self.title_embedding = tf.keras.Sequential([
      tf.keras.layers.experimental.preprocessing.StringLookup(
          vocabulary=unique_movie_titles, mask_token=None),
      tf.keras.layers.Embedding(len(unique_movie_titles) + 1, 32)
    ])

    self.title_vectorizer = tf.keras.layers.experimental.preprocessing.TextVectorization(
        max_tokens=max_tokens)

    self.title_text_embedding = tf.keras.Sequential([
      self.title_vectorizer,
      tf.keras.layers.Embedding(max_tokens, 32, mask_zero=True),
      tf.keras.layers.GlobalAveragePooling1D(),
    ])

    self.title_vectorizer.adapt(movies)

  def call(self, titles):
    return tf.concat([
        self.title_embedding(titles),
        self.title_text_embedding(titles),
    ], axis=1)

In [6]:
class MovielensModel(tfrs.models.Model):

  def __init__(self, use_timestamps):
    super().__init__()
    self.query_model = tf.keras.Sequential([
      UserModel(use_timestamps),
      tf.keras.layers.Dense(32)
    ])
    self.candidate_model = tf.keras.Sequential([
      MovieModel(),
      tf.keras.layers.Dense(32)
    ])
    self.task = tfrs.tasks.Retrieval(
        metrics=tfrs.metrics.FactorizedTopK(
            candidates=movies.batch(128).map(self.candidate_model),
        ),
    )

  def compute_loss(self, features, training=False):
    # We only pass the user id and timestamp features into the query model. This
    # is to ensure that the training inputs would have the same keys as the
    # query inputs. Otherwise the discrepancy in input structure would cause an
    # error when loading the query model after saving it.
    query_embeddings = self.query_model({
        "user_id": features["user_id"],
        "timestamp": features["timestamp"],
    })
    movie_embeddings = self.candidate_model(features["movie_title"])

    return self.task(query_embeddings, movie_embeddings)

In [7]:
tf.random.set_seed(42)
shuffled = ratings.shuffle(100_000, seed=42, reshuffle_each_iteration=False)

train = shuffled.take(80_000)
test = shuffled.skip(80_000).take(20_000)

cached_train = train.shuffle(100_000).batch(2048)
cached_test = test.batch(4096).cache()

In [8]:
model = MovielensModel(use_timestamps=False)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))

model.fit(cached_train, epochs=3)

train_accuracy = model.evaluate(
    cached_train, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]
test_accuracy = model.evaluate(
    cached_test, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]

print(f"Top-100 accuracy (train): {train_accuracy:.2f}.")
print(f"Top-100 accuracy (test): {test_accuracy:.2f}.")

Epoch 1/3
Consider rewriting this model with the Functional API.


Consider rewriting this model with the Functional API.


Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.


Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.


Consider rewriting this model with the Functional API.


Consider rewriting this model with the Functional API.


Epoch 2/3
Epoch 3/3
Consider rewriting this model with the Functional API.


Consider rewriting this model with the Functional API.


Top-100 accuracy (train): 0.29.
Top-100 accuracy (test): 0.21.


In [9]:
model = MovielensModel(use_timestamps=True)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))

model.fit(cached_train, epochs=3)

train_accuracy = model.evaluate(
    cached_train, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]
test_accuracy = model.evaluate(
    cached_test, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]

print(f"Top-100 accuracy (train): {train_accuracy:.2f}.")
print(f"Top-100 accuracy (test): {test_accuracy:.2f}.")

Epoch 1/3
Consider rewriting this model with the Functional API.


Consider rewriting this model with the Functional API.










Consider rewriting this model with the Functional API.


Consider rewriting this model with the Functional API.






Epoch 2/3
Epoch 3/3
Consider rewriting this model with the Functional API.


Consider rewriting this model with the Functional API.






Top-100 accuracy (train): 0.36.
Top-100 accuracy (test): 0.25.


In [10]:
scann_index = tfrs.layers.factorized_top_k.ScaNN(model.query_model)
scann_index.index(movies.batch(100).map(model.candidate_model), movies)

<tensorflow_recommenders.layers.factorized_top_k.ScaNN at 0x7f450058e208>

In [11]:
dict_batch = next(iter(cached_train))

In [12]:
user_id = dict_batch['user_id']
timestamp = dict_batch['timestamp']

In [13]:
user_id[0], timestamp[0]

(<tf.Tensor: shape=(), dtype=string, numpy=b'178'>,
 <tf.Tensor: shape=(), dtype=int64, numpy=882826556>)

In [14]:
dict_input = {'user_id': '327', 'timestamp': 887745662}

In [15]:
# Get recommendations.
_, titles = scann_index(dict_batch)

array([b'Babe (1995)', b'My Man Godfrey (1936)', b'Chasing Amy (1997)',
       ..., b'Tin Cup (1996)', b"What's Love Got to Do with It (1993)",
       b'Jerry Maguire (1996)'], dtype=object)>, 'user_id': <tf.Tensor: shape=(2048,), dtype=string, numpy=array([b'178', b'370', b'710', ..., b'943', b'406', b'69'], dtype=object)>, 'timestamp': <tf.Tensor: shape=(2048,), dtype=int64, numpy=
array([882826556, 879434587, 882063276, ..., 875502192, 882480890,
       882072920])>}
Consider rewriting this model with the Functional API.


array([b'Babe (1995)', b'My Man Godfrey (1936)', b'Chasing Amy (1997)',
       ..., b'Tin Cup (1996)', b"What's Love Got to Do with It (1993)",
       b'Jerry Maguire (1996)'], dtype=object)>, 'user_id': <tf.Tensor: shape=(2048,), dtype=string, numpy=array([b'178', b'370', b'710', ..., b'943', b'406', b'69'], dtype=object)>, 'timestamp': <tf.Tensor: shape=(2048,), dtype=int64, numpy=
array([882826556, 879434587, 882063276, ..., 875502192, 882480890,
       882072920])>}
Consider rewriting this model with the Functional API.






In [16]:
print(f"Recommendations for user 42: {titles[0, :3]}")

Recommendations for user 42: [b'Scarlet Letter, The (1926)' b'In the Line of Fire (1993)'
 b'Days of Thunder (1990)']
