From aec1fec62873b52fe1925b5512c8a99bc9ac89b9 Mon Sep 17 00:00:00 2001 From: Taylor Robie Date: Thu, 4 Oct 2018 17:00:36 -0700 Subject: [PATCH] Fix/ncf eval default (#5438) * improve default handling for eval_batch_size * return eval_batch_size default to None * fix syntax error --- official/recommendation/ncf_main.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/official/recommendation/ncf_main.py b/official/recommendation/ncf_main.py index f83787162..f042d31f4 100644 --- a/official/recommendation/ncf_main.py +++ b/official/recommendation/ncf_main.py @@ -128,8 +128,9 @@ def run_ncf(_): batch_size = distribution_utils.per_device_batch_size( int(FLAGS.batch_size), num_gpus) - eval_batch_size = int(FLAGS.eval_batch_size or FLAGS.batch_size) eval_per_user = rconst.NUM_EVAL_NEGATIVES + 1 + eval_batch_size = int(FLAGS.eval_batch_size or + max([FLAGS.batch_size, eval_per_user])) if eval_batch_size % eval_per_user: eval_batch_size = eval_batch_size // eval_per_user * eval_per_user tf.logging.warning( @@ -365,7 +366,8 @@ def define_ncf_flags(): @flags.validator("eval_batch_size", "eval_batch_size must be at least {}" .format(rconst.NUM_EVAL_NEGATIVES + 1)) def eval_size_check(eval_batch_size): - return int(eval_batch_size) > rconst.NUM_EVAL_NEGATIVES + return (eval_batch_size is None or + int(eval_batch_size) > rconst.NUM_EVAL_NEGATIVES) if __name__ == "__main__": -- GitLab