From dbd3088c63acff508c46c2368838413edeed08f3 Mon Sep 17 00:00:00 2001 From: Neal Wu Date: Mon, 6 Nov 2017 19:41:51 -0800 Subject: [PATCH] Make wide_deep_test much less flaky --- official/wide_deep/wide_deep_test.csv | 10 ++++++++++ official/wide_deep/wide_deep_test.py | 11 ++++++++--- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/official/wide_deep/wide_deep_test.csv b/official/wide_deep/wide_deep_test.csv index f1814ad66..374397dbd 100644 --- a/official/wide_deep/wide_deep_test.csv +++ b/official/wide_deep/wide_deep_test.csv @@ -18,3 +18,13 @@ 32,Private,186824,HS-grad,9,Never-married,Machine-op-inspct,Unmarried,,,0,0,40,,<=50K 38,Private,28887,11th,7,Married-civ-spouse,Sales,Husband,,,0,0,50,,<=50K 43,Self-emp-not-inc,292175,Masters,14,Divorced,Exec-managerial,Unmarried,,,0,0,45,,>50K +40,Private,193524,Doctorate,16,Married-civ-spouse,Prof-specialty,Husband,,,0,0,60,,>50K +56,Local-gov,216851,Bachelors,13,Married-civ-spouse,Tech-support,Husband,,,0,0,40,,>50K +54,?,180211,Some-college,10,Married-civ-spouse,?,Husband,,,0,0,60,,>50K +22,State-gov,311512,Some-college,10,Married-civ-spouse,Other-service,Husband,,,0,0,15,,<=50K +31,Private,84154,Some-college,10,Married-civ-spouse,Sales,Husband,,,0,0,38,,>50K +57,Federal-gov,337895,Bachelors,13,Married-civ-spouse,Prof-specialty,Husband,,,0,0,40,,>50K +47,Private,51835,Prof-school,15,Married-civ-spouse,Prof-specialty,Wife,,,0,1902,60,,>50K +50,Federal-gov,251585,Bachelors,13,Divorced,Exec-managerial,Not-in-family,,,0,0,55,,>50K +25,Private,289980,HS-grad,9,Never-married,Handlers-cleaners,Not-in-family,,,0,0,35,,<=50K +42,Private,116632,Doctorate,16,Married-civ-spouse,Prof-specialty,Husband,,,0,0,45,,>50K diff --git a/official/wide_deep/wide_deep_test.py b/official/wide_deep/wide_deep_test.py index ff19a699d..b09cbdd54 100644 --- a/official/wide_deep/wide_deep_test.py +++ b/official/wide_deep/wide_deep_test.py @@ -85,18 +85,23 @@ class BaseTest(tf.test.TestCase): input_fn=lambda: wide_deep.input_fn( TEST_CSV, num_epochs=1, shuffle=False, batch_size=1)) - # Train for 40 steps at batch size 2 and evaluate final loss + # Train for 100 epochs at batch size 3 and evaluate final loss model.train( input_fn=lambda: wide_deep.input_fn( - TEST_CSV, num_epochs=None, shuffle=True, batch_size=2), - steps=40) + TEST_CSV, num_epochs=100, shuffle=True, batch_size=3)) final_results = model.evaluate( input_fn=lambda: wide_deep.input_fn( TEST_CSV, num_epochs=1, shuffle=False, batch_size=1)) print('%s initial results:' % model_type, initial_results) print('%s final results:' % model_type, final_results) + + # Ensure loss has decreased, while accuracy and both AUCs have increased. self.assertLess(final_results['loss'], initial_results['loss']) + self.assertGreater(final_results['auc'], initial_results['auc']) + self.assertGreater(final_results['auc_precision_recall'], + initial_results['auc_precision_recall']) + self.assertGreater(final_results['accuracy'], initial_results['accuracy']) def test_wide_deep_estimator_training(self): self.build_and_test_estimator('wide_deep') -- GitLab