提交 91000bc5 编写于 作者: S Shawn Wang

Refactor neumf_model.py to support users who just need top_k and ndcg tensors.

上级 69b01644
......@@ -36,12 +36,13 @@ from __future__ import print_function
import sys
import typing
import google3
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from official.datasets import movielens # pylint: disable=g-bad-import-order
from official.recommendation import constants as rconst
from official.recommendation import stat_utils
from google3.third_party.tensorflow_models.official.datasets import movielens # pylint: disable=g-bad-import-order
from google3.third_party.tensorflow_models.official.recommendation import constants as rconst
from google3.third_party.tensorflow_models.official.recommendation import stat_utils
def _sparse_to_dense_grads(grads_and_vars):
......@@ -298,39 +299,8 @@ def compute_eval_loss_and_metrics(logits, # type: tf.Tensor
Returns:
An EstimatorSpec for evaluation.
"""
logits_by_user = tf.reshape(logits, (-1, rconst.NUM_EVAL_NEGATIVES + 1))
duplicate_mask_by_user = tf.reshape(duplicate_mask,
(-1, rconst.NUM_EVAL_NEGATIVES + 1))
if match_mlperf:
# Set duplicate logits to the min value for that dtype. The MLPerf
# reference dedupes during evaluation.
logits_by_user *= (1 - duplicate_mask_by_user)
logits_by_user += duplicate_mask_by_user * logits_by_user.dtype.min
# Determine the location of the first element in each row after the elements
# are sorted.
sort_indices = tf.contrib.framework.argsort(
logits_by_user, axis=1, direction="DESCENDING")
# Use matrix multiplication to extract the position of the true item from the
# tensor of sorted indices. This approach is chosen because both GPUs and TPUs
# perform matrix multiplications very quickly. This is similar to np.argwhere.
# However this is a special case because the target will only appear in
# sort_indices once.
one_hot_position = tf.cast(tf.equal(sort_indices, 0), tf.int32)
sparse_positions = tf.multiply(
one_hot_position, tf.range(logits_by_user.shape[1])[tf.newaxis, :])
position_vector = tf.reduce_sum(sparse_positions, axis=1)
in_top_k = tf.cast(tf.less(position_vector, rconst.TOP_K), tf.float32)
ndcg = tf.log(2.) / tf.log(tf.cast(position_vector, tf.float32) + 2)
ndcg *= in_top_k
# If a row is a padded row, all but the first element will be a duplicate.
metric_weights = tf.not_equal(tf.reduce_sum(duplicate_mask_by_user, axis=1),
rconst.NUM_EVAL_NEGATIVES)
in_top_k, ndcg, metric_weights, logits_by_user = compute_top_k_and_ndcg(
logits, duplicate_mask, match_mlperf)
# Examples are provided by the eval Dataset in a structured format, so eval
# labels can be reconstructed on the fly.
......@@ -375,3 +345,60 @@ def compute_eval_loss_and_metrics(logits, # type: tf.Tensor
loss=cross_entropy,
eval_metric_ops=metric_fn(in_top_k, ndcg, metric_weights)
)
def compute_top_k_and_ndcg(logits, # type: tf.Tensor
duplicate_mask, # type: tf.Tensor
match_mlperf=False # type: bool
):
"""Compute inputs of metric calculation.
Args:
logits: A tensor containing the predicted logits for each user. The shape
of logits is (num_users_per_batch * (1 + NUM_EVAL_NEGATIVES),) Logits
for a user are grouped, and the first element of the group is the true
element.
duplicate_mask: A vector with the same shape as logits, with a value of 1
if the item corresponding to the logit at that position has already
appeared for that user.
match_mlperf: Use the MLPerf reference convention for computing rank.
Returns:
is_top_k, ndcg and weights, all of which has size (num_users_in_batch,), and
logits_by_user which has size
(num_users_in_batch, (rconst.NUM_EVAL_NEGATIVES + 1)).
"""
logits_by_user = tf.reshape(logits, (-1, rconst.NUM_EVAL_NEGATIVES + 1))
duplicate_mask_by_user = tf.reshape(duplicate_mask,
(-1, rconst.NUM_EVAL_NEGATIVES + 1))
if match_mlperf:
# Set duplicate logits to the min value for that dtype. The MLPerf
# reference dedupes during evaluation.
logits_by_user *= (1 - duplicate_mask_by_user)
logits_by_user += duplicate_mask_by_user * logits_by_user.dtype.min
# Determine the location of the first element in each row after the elements
# are sorted.
sort_indices = tf.contrib.framework.argsort(
logits_by_user, axis=1, direction="DESCENDING")
# Use matrix multiplication to extract the position of the true item from the
# tensor of sorted indices. This approach is chosen because both GPUs and TPUs
# perform matrix multiplications very quickly. This is similar to np.argwhere.
# However this is a special case because the target will only appear in
# sort_indices once.
one_hot_position = tf.cast(tf.equal(sort_indices, 0), tf.int32)
sparse_positions = tf.multiply(
one_hot_position, tf.range(logits_by_user.shape[1])[tf.newaxis, :])
position_vector = tf.reduce_sum(sparse_positions, axis=1)
in_top_k = tf.cast(tf.less(position_vector, rconst.TOP_K), tf.float32)
ndcg = tf.log(2.) / tf.log(tf.cast(position_vector, tf.float32) + 2)
ndcg *= in_top_k
# If a row is a padded row, all but the first element will be a duplicate.
metric_weights = tf.not_equal(tf.reduce_sum(duplicate_mask_by_user, axis=1),
rconst.NUM_EVAL_NEGATIVES)
return in_top_k, ndcg, metric_weights, logits_by_user
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册