diff --git a/official/recommendation/ncf_main.py b/official/recommendation/ncf_main.py index f837871628a7b627dc9b44d984830095fc8b1023..f042d31f43bb966bd30c8a19f4ee85df4f9be9e3 100644 --- a/official/recommendation/ncf_main.py +++ b/official/recommendation/ncf_main.py @@ -128,8 +128,9 @@ def run_ncf(_): batch_size = distribution_utils.per_device_batch_size( int(FLAGS.batch_size), num_gpus) - eval_batch_size = int(FLAGS.eval_batch_size or FLAGS.batch_size) eval_per_user = rconst.NUM_EVAL_NEGATIVES + 1 + eval_batch_size = int(FLAGS.eval_batch_size or + max([FLAGS.batch_size, eval_per_user])) if eval_batch_size % eval_per_user: eval_batch_size = eval_batch_size // eval_per_user * eval_per_user tf.logging.warning( @@ -365,7 +366,8 @@ def define_ncf_flags(): @flags.validator("eval_batch_size", "eval_batch_size must be at least {}" .format(rconst.NUM_EVAL_NEGATIVES + 1)) def eval_size_check(eval_batch_size): - return int(eval_batch_size) > rconst.NUM_EVAL_NEGATIVES + return (eval_batch_size is None or + int(eval_batch_size) > rconst.NUM_EVAL_NEGATIVES) if __name__ == "__main__":