Commit 9f81374c authored by raymondxyang, committed by Rasmus Munk Larsen

Add option to build more Python tests in CMake (#11853)

* Ignore the Windows build directory

* Fix deprecated methods in tf.contrib.python

* Fix regex match for Windows build in contrib.keras

* Fix Regex match for Windows build in session_bundle

* Fix deprecated methods
* Fix regex match for Windows
* Fix compatibility issue with Python 3.x

* Add missing ops to the Windows build for testing

* Enable more test cases for the Windows build

* Clean up code and fix typos

* Add a conditional CMake mode for enabling more unit test cases

* Add a CMake mode for major contrib packages

* Add supplementary info in README for the new CMake option

* Update tf_tests after testing with TF 1.3
* Clean code and resolve conflicts

* Fix unsafe regex matches and format code

* Update exclude list after testing with latest master branch

* Fix missing module
Parent 98f0e1ef
@@ -13,4 +13,5 @@ node_modules
 __pycache__
 *.swp
 .vscode/
+cmake_build/
 .idea/**
@@ -29,6 +29,7 @@ option(tensorflow_BUILD_ALL_KERNELS "Build all OpKernels" ON)
 option(tensorflow_BUILD_CONTRIB_KERNELS "Build OpKernels from tensorflow/contrib/..." ON)
 option(tensorflow_BUILD_CC_TESTS "Build cc unit tests " OFF)
 option(tensorflow_BUILD_PYTHON_TESTS "Build python unit tests " OFF)
+option(tensorflow_BUILD_MORE_PYTHON_TESTS "Build more python unit tests for contrib packages" OFF)
 option(tensorflow_BUILD_SHARED_LIB "Build TensorFlow as a shared library" OFF)
 option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON)
 option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions")
...
@@ -241,6 +241,13 @@ Step-by-step Windows build
     ```
     ctest -C RelWithDebInfo
     ```
+   * `-Dtensorflow_BUILD_MORE_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This enables Python tests on
+     several major contrib packages. The option only takes effect when both it and
+     `tensorflow_BUILD_PYTHON_TESTS` are set to `ON`. After building the Python wheel, install the
+     new wheel before running the tests. To execute the tests, use
+     ```
+     ctest -C RelWithDebInfo
+     ```
 4. Invoke MSBuild to build TensorFlow.
...
@@ -76,7 +76,9 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
     #"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/encode_audio_op.cc"
     "${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/generate_vocab_remapping_op.cc"
     "${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/load_and_remap_matrix_op.cc"
+    "${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/zero_initializer_op.cc"
     "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/checkpoint_ops.cc"
+    "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc"
     "${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc"
     "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc"
     "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc"
...
@@ -156,6 +156,21 @@ if (tensorflow_BUILD_PYTHON_TESTS)
     "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/*_test.py"
   )
+  if (tensorflow_BUILD_MORE_PYTHON_TESTS)
+    # Adding other major packages
+    file(GLOB_RECURSE tf_test_src_py
+      ${tf_test_src_py}
+      "${tensorflow_source_dir}/tensorflow/contrib/legacy_seq2seq/*_test.py"
+      "${tensorflow_source_dir}/tensorflow/contrib/linalg/*_test.py"
+      "${tensorflow_source_dir}/tensorflow/contrib/graph_editor/*_test.py"
+      "${tensorflow_source_dir}/tensorflow/contrib/bayesflow/*_test.py"
+      "${tensorflow_source_dir}/tensorflow/contrib/framework/*_test.py"
+      "${tensorflow_source_dir}/tensorflow/contrib/keras/*_test.py"
+      "${tensorflow_source_dir}/tensorflow/contrib/distributions/*_test.py"
+      "${tensorflow_source_dir}/tensorflow/contrib/learn/*_test.py"
+    )
+  endif()
   # exclude the ones we don't want
   set(tf_test_src_py_exclude
     # Python source line inspection tests are flaky on Windows (b/36375074).
@@ -183,6 +198,9 @@ if (tensorflow_BUILD_PYTHON_TESTS)
     # Loading resources in contrib doesn't seem to work on Windows
     "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/client/random_forest_test.py"
     "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py"
+    # dask needs a fix
+    "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py"
     # Test is flaky on Windows GPU builds (b/38283730).
     "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/gmm_test.py"
   )
@@ -215,11 +233,8 @@ if (tensorflow_BUILD_PYTHON_TESTS)
     "${tensorflow_source_dir}/tensorflow/python/kernel_tests/py_func_test.py"
     # training tests
     "${tensorflow_source_dir}/tensorflow/python/training/basic_session_run_hooks_test.py" # Needs tf.contrib fix.
-    "${tensorflow_source_dir}/tensorflow/python/training/evaluation_test.py" # Needs tf.contrib fix.
     "${tensorflow_source_dir}/tensorflow/python/training/localhost_cluster_performance_test.py" # Needs portpicker.
-    "${tensorflow_source_dir}/tensorflow/python/training/monitored_session_test.py" # Needs tf.contrib fix.
     "${tensorflow_source_dir}/tensorflow/python/training/quantize_training_test.py" # Needs quantization ops to be included in windows.
-    "${tensorflow_source_dir}/tensorflow/python/training/saver_large_variable_test.py" # Overflow error.
     "${tensorflow_source_dir}/tensorflow/python/training/supervisor_test.py" # Flaky I/O error on rename.
     "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" # Needs portpicker.
     "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops
@@ -233,6 +248,45 @@ if (tensorflow_BUILD_PYTHON_TESTS)
     "${tensorflow_source_dir}/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py" # No libcurl support
     # Newly running on Windows since TensorBoard backend move. Fail on Windows and need debug.
     "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows.
+    # Dask.Dataframe bugs on Windows builds
+    "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/tests/dataframe/tensorflow_dataframe_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/io_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/graph_actions_test.py"
+    # Need extra build
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/conditional_distribution_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py"
+    # Windows Path
+    "${tensorflow_source_dir}/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py" # TODO: Fix path
+    "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/models_test.py"
+    # Related to Windows Multiprocessing https://github.com/fchollet/keras/issues/5071
+    "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/engine/training_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/utils/data_utils_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/callbacks_test.py"
+    # Scipy needed
+    "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/preprocessing/image_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/logistic_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/relaxed_bernoulli_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py"
+    # Failing with TF 1.3 (TODO)
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/estimator_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py"
    )
  endif()
  list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude})
...
@@ -49,6 +49,7 @@ from tensorflow.contrib.distributions.python.ops.quantized_distribution import *
 from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import *
 from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import *
 from tensorflow.contrib.distributions.python.ops.sample_stats import *
+from tensorflow.contrib.distributions.python.ops.test_util import *
 from tensorflow.contrib.distributions.python.ops.vector_exponential_diag import *
 from tensorflow.contrib.distributions.python.ops.vector_laplace_diag import *
 from tensorflow.contrib.distributions.python.ops.wishart import *
...
@@ -562,7 +562,7 @@ def assign_from_checkpoint(model_path, var_list, ignore_missing_vars=False):
       grouped_vars[ckpt_name].append(var)
   else:
-    for ckpt_name, value in var_list.iteritems():
+    for ckpt_name, value in var_list.items():
       if isinstance(value, (tuple, list)):
         grouped_vars[ckpt_name] = value
       else:
...
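The single changed line above is a Python 3 compatibility fix: `dict.iteritems()` exists only in Python 2, while `dict.items()` works on both versions. A minimal standalone illustration (plain Python, not TensorFlow code; the dictionary contents are made up):

```python
# Hypothetical var_list mapping checkpoint names to values, for illustration only.
var_list = {"weights": ["w_0", "w_1"], "bias": "b_0"}

# Python 2 only: var_list.iteritems()  -> AttributeError on Python 3.
# Portable form, as used in the patched assign_from_checkpoint loop:
for ckpt_name, value in var_list.items():
  if isinstance(value, (tuple, list)):
    print(ckpt_name, "maps to a list of", len(value), "values")
  else:
    print(ckpt_name, "maps to a single value:", value)
```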
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function

 import marshal
+import os
 import sys
 import time
 import types as python_types
@@ -195,7 +196,10 @@ def func_dump(func):
   Returns:
       A tuple `(code, defaults, closure)`.
   """
-  code = marshal.dumps(func.__code__).decode('raw_unicode_escape')
+  if os.name == 'nt':
+    code = marshal.dumps(func.__code__).replace(b'\\', b'/').decode('raw_unicode_escape')
+  else:
+    code = marshal.dumps(func.__code__).decode('raw_unicode_escape')
   defaults = func.__defaults__
   if func.__closure__:
     closure = tuple(c.cell_contents for c in func.__closure__)
...
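`func_dump` serializes a function's code object with `marshal` and decodes the bytes as `raw_unicode_escape` text. On Windows the marshalled data embeds the source file path, and a backslash followed by `u` (as in `C:\users\...`) is then misread as a `\uXXXX` escape, breaking the round-trip. Replacing backslashes with forward slashes sidesteps that. Below is a standalone sketch of the same dump/load scheme; `dump_code`, `load_code`, and `square` are illustrative names, not the Keras helpers, and the byte-level replace assumes backslash bytes only occur inside path strings:

```python
import marshal
import os
import types


def dump_code(func):
  # Serialize the function's code object to text, normalizing Windows
  # backslashes first (mirrors the patched func_dump; defaults/closure omitted).
  raw = marshal.dumps(func.__code__)
  if os.name == 'nt':
    raw = raw.replace(b'\\', b'/')
  return raw.decode('raw_unicode_escape')


def load_code(text):
  # Inverse operation: re-encode the text and rebuild a callable.
  code = marshal.loads(text.encode('raw_unicode_escape'))
  return types.FunctionType(code, globals())


def square(x):
  return x * x


restored = load_code(dump_code(square))
assert restored(7) == 49
```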
@@ -505,7 +505,7 @@ class EstimatorModelFnTest(test.TestCase):
       return input_fn_utils.InputFnOps(
           features, labels, {'examples': serialized_tf_example})

-    est.export_savedmodel(est.model_dir + '/export', serving_input_fn)
+    est.export_savedmodel(os.path.join(est.model_dir, 'export'), serving_input_fn)
     self.assertTrue(self.mock_saver.restore.called)
@@ -955,10 +955,11 @@ class EstimatorTest(test.TestCase):
     self.assertTrue('input_example_tensor' in graph_ops)
     self.assertTrue('ParseExample/ParseExample' in graph_ops)
     self.assertTrue('linear/linear/feature/matmul' in graph_ops)
-    self.assertSameElements(
+    self.assertItemsEqual(
         ['bogus_lookup', 'feature'],
-        graph.get_collection(
-            constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS))
+        [compat.as_str_any(x) for x in graph.get_collection(
+            constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS)])

     # cleanup
     gfile.DeleteRecursively(tmpdir)
...
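The second hunk above also decodes the collection entries with `compat.as_str_any` before comparing, because on Python 3 string values read back from a graph collection can arrive as `bytes`, and `b'feature' != 'feature'`. A standalone illustration of the mismatch and the fix (plain Python; `as_str_any` here is a simplified stand-in for `tensorflow.python.util.compat.as_str_any`):

```python
expected = ['bogus_lookup', 'feature']
collection = [b'feature', b'bogus_lookup']  # byte strings, as they can come back on Python 3


def as_str_any(value):
  # Simplified stand-in for tensorflow.python.util.compat.as_str_any.
  return value.decode('utf-8') if isinstance(value, bytes) else str(value)


# Direct comparison fails on Python 3: bytes never equal str.
assert set(collection) != set(expected)

# Decoding first gives the order-insensitive equality the patched test
# asserts via assertItemsEqual.
assert sorted(as_str_any(x) for x in collection) == sorted(expected)
```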
@@ -31,6 +31,7 @@ from tensorflow.contrib.session_bundle import exporter
 from tensorflow.contrib.session_bundle import manifest_pb2
 from tensorflow.python.client import session
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import gfile
@@ -49,9 +50,8 @@ def _training_input_fn():
 class ExportTest(test.TestCase):

   def _get_default_signature(self, export_meta_filename):
-    """Gets the default signature from the export.meta file."""
+    """ Gets the default signature from the export.meta file. """
     with session.Session():
       save = saver.import_meta_graph(export_meta_filename)
       meta_graph_def = save.export_meta_graph()
@@ -68,18 +68,19 @@ class ExportTest(test.TestCase):
     self.assertTrue(gfile.Exists(export_dir))
     # Only the written checkpoints are exported.
     self.assertTrue(
-        saver.checkpoint_exists(export_dir + '00000001/export'),
+        saver.checkpoint_exists(os.path.join(export_dir, '00000001', 'export')),
         'Exported checkpoint expected but not found: %s' %
-        (export_dir + '00000001/export'))
+        os.path.join(export_dir, '00000001', 'export'))
     self.assertTrue(
-        saver.checkpoint_exists(export_dir + '00000010/export'),
+        saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')),
         'Exported checkpoint expected but not found: %s' %
-        (export_dir + '00000010/export'))
+        os.path.join(export_dir, '00000010', 'export'))
     self.assertEquals(
         six.b(os.path.join(export_dir, '00000010')),
         export_monitor.last_export_dir)
     # Validate the signature
-    signature = self._get_default_signature(export_dir + '00000010/export.meta')
+    signature = self._get_default_signature(
+        os.path.join(export_dir, '00000010', 'export.meta'))
     self.assertTrue(signature.HasField(expected_signature))

   def testExportMonitor_EstimatorProvidesSignature(self):
@@ -88,7 +89,7 @@ class ExportTest(test.TestCase):
     y = 2 * x + 3
     cont_features = [feature_column.real_valued_column('', dimension=1)]
     regressor = learn.LinearRegressor(feature_columns=cont_features)
-    export_dir = tempfile.mkdtemp() + 'export/'
+    export_dir = os.path.join(tempfile.mkdtemp(), 'export')
     export_monitor = learn.monitors.ExportMonitor(
         every_n_steps=1, export_dir=export_dir, exports_to_keep=2)
     regressor.fit(x, y, steps=10, monitors=[export_monitor])
@@ -99,7 +100,7 @@ class ExportTest(test.TestCase):
     x = np.random.rand(1000)
     y = 2 * x + 3
     cont_features = [feature_column.real_valued_column('', dimension=1)]
-    export_dir = tempfile.mkdtemp() + 'export/'
+    export_dir = os.path.join(tempfile.mkdtemp(), 'export')
     export_monitor = learn.monitors.ExportMonitor(
         every_n_steps=1,
         export_dir=export_dir,
@@ -122,7 +123,7 @@ class ExportTest(test.TestCase):
     input_feature_key = 'my_example_key'
     monitor = learn.monitors.ExportMonitor(
         every_n_steps=1,
-        export_dir=tempfile.mkdtemp() + 'export/',
+        export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
         input_fn=_serving_input_fn,
         input_feature_key=input_feature_key,
         exports_to_keep=2,
@@ -140,7 +141,7 @@ class ExportTest(test.TestCase):
     monitor = learn.monitors.ExportMonitor(
         every_n_steps=1,
-        export_dir=tempfile.mkdtemp() + 'export/',
+        export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
         input_fn=_serving_input_fn,
         input_feature_key=input_feature_key,
         exports_to_keep=2,
@@ -165,7 +166,7 @@ class ExportTest(test.TestCase):
     monitor = learn.monitors.ExportMonitor(
         every_n_steps=1,
-        export_dir=tempfile.mkdtemp() + 'export/',
+        export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
         input_fn=_serving_input_fn,
         input_feature_key=input_feature_key,
         exports_to_keep=2,
@@ -187,7 +188,7 @@ class ExportTest(test.TestCase):
     monitor = learn.monitors.ExportMonitor(
         every_n_steps=1,
-        export_dir=tempfile.mkdtemp() + 'export/',
+        export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
         input_fn=_serving_input_fn,
         input_feature_key=input_feature_key,
         exports_to_keep=2,
@@ -210,7 +211,7 @@ class ExportTest(test.TestCase):
             shape=(1,), minval=0.0, maxval=1000.0)
       }, None

-    export_dir = tempfile.mkdtemp() + 'export/'
+    export_dir = os.path.join(tempfile.mkdtemp(), 'export')
     monitor = learn.monitors.ExportMonitor(
         every_n_steps=1,
         export_dir=export_dir,
@@ -235,7 +236,7 @@ class ExportTest(test.TestCase):
     y = 2 * x + 3
     cont_features = [feature_column.real_valued_column('', dimension=1)]
     regressor = learn.LinearRegressor(feature_columns=cont_features)
-    export_dir = tempfile.mkdtemp() + 'export/'
+    export_dir = os.path.join(tempfile.mkdtemp(), 'export')
     export_monitor = learn.monitors.ExportMonitor(
         every_n_steps=1,
         export_dir=export_dir,
@@ -244,10 +245,13 @@ class ExportTest(test.TestCase):
     regressor.fit(x, y, steps=10, monitors=[export_monitor])
     self.assertTrue(gfile.Exists(export_dir))
-    self.assertFalse(saver.checkpoint_exists(export_dir + '00000000/export'))
-    self.assertTrue(saver.checkpoint_exists(export_dir + '00000010/export'))
+    with self.assertRaises(errors.NotFoundError):
+      saver.checkpoint_exists(os.path.join(export_dir, '00000000', 'export'))
+    self.assertTrue(
+        saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')))
     # Validate the signature
-    signature = self._get_default_signature(export_dir + '00000010/export.meta')
+    signature = self._get_default_signature(
+        os.path.join(export_dir, '00000010', 'export.meta'))
     self.assertTrue(signature.HasField('regression_signature'))
...
@@ -33,8 +33,13 @@ from tensorflow.python.util import compat
 def _create_parser(base_dir):
   # create a simple parser that pulls the export_version from the directory.
   def parser(path):
-    match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$",
-                     compat.as_str_any(path.path))
+    # Modify the path object for RegEx match for Windows Paths
+    if os.name == 'nt':
+      match = re.match("^" + compat.as_str_any(base_dir).replace('\\', '/') + "/(\\d+)$",
+                       compat.as_str_any(path.path).replace('\\', '/'))
+    else:
+      match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$",
+                       compat.as_str_any(path.path))
     if not match:
       return None
     return path._replace(export_version=int(match.group(1)))
@@ -48,13 +53,13 @@ class GcTest(test_util.TensorFlowTestCase):
     paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)]
     newest = gc.largest_export_versions(2)
     n = newest(paths)
-    self.assertEquals(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)])
+    self.assertEqual(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)])

   def testLargestExportVersionsDoesNotDeleteZeroFolder(self):
     paths = [gc.Path("/foo", 0), gc.Path("/foo", 3)]
     newest = gc.largest_export_versions(2)
     n = newest(paths)
-    self.assertEquals(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)])
+    self.assertEqual(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)])

   def testModExportVersion(self):
     paths = [
@@ -62,9 +67,9 @@ class GcTest(test_util.TensorFlowTestCase):
         gc.Path("/foo", 9)
     ]
     mod = gc.mod_export_version(2)
-    self.assertEquals(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)])
+    self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)])
     mod = gc.mod_export_version(3)
-    self.assertEquals(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)])
+    self.assertEqual(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)])

   def testOneOfEveryNExportVersions(self):
     paths = [
@@ -73,7 +78,7 @@ class GcTest(test_util.TensorFlowTestCase):
         gc.Path("/foo", 8), gc.Path("/foo", 33)
     ]
     one_of = gc.one_of_every_n_export_versions(3)
-    self.assertEquals(
+    self.assertEqual(
         one_of(paths), [
             gc.Path("/foo", 3), gc.Path("/foo", 6), gc.Path("/foo", 8),
             gc.Path("/foo", 33)
@@ -84,14 +89,14 @@ class GcTest(test_util.TensorFlowTestCase):
     # Test that here.
     paths = [gc.Path("/foo", 0), gc.Path("/foo", 4), gc.Path("/foo", 5)]
     one_of = gc.one_of_every_n_export_versions(3)
-    self.assertEquals(one_of(paths), [gc.Path("/foo", 0), gc.Path("/foo", 5)])
+    self.assertEqual(one_of(paths), [gc.Path("/foo", 0), gc.Path("/foo", 5)])

   def testUnion(self):
     paths = []
     for i in xrange(10):
       paths.append(gc.Path("/foo", i))
     f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3))
-    self.assertEquals(
+    self.assertEqual(
         f(paths), [
             gc.Path("/foo", 0), gc.Path("/foo", 3), gc.Path("/foo", 6),
             gc.Path("/foo", 7), gc.Path("/foo", 8), gc.Path("/foo", 9)
@@ -103,9 +108,9 @@ class GcTest(test_util.TensorFlowTestCase):
         gc.Path("/foo", 9)
     ]
     mod = gc.negation(gc.mod_export_version(2))
-    self.assertEquals(mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)])
+    self.assertEqual(mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)])
     mod = gc.negation(gc.mod_export_version(3))
-    self.assertEquals(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)])
+    self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)])

   def testPathsWithParse(self):
     base_dir = os.path.join(test.get_temp_dir(), "paths_parse")
@@ -115,7 +120,7 @@ class GcTest(test_util.TensorFlowTestCase):
     # add a base_directory to ignore
     gfile.MakeDirs(os.path.join(base_dir, "ignore"))

-    self.assertEquals(
+    self.assertEqual(
         gc.get_paths(base_dir, _create_parser(base_dir)),
         [
             gc.Path(os.path.join(base_dir, "0"), 0),
...
@@ -301,7 +301,12 @@ class Exporter(object):
     if exports_to_keep:
       # create a simple parser that pulls the export_version from the directory.
       def parser(path):
-        match = re.match("^" + export_dir_base + "/(\\d{8})$", path.path)
+        if os.name == 'nt':
+          match = re.match("^" + export_dir_base.replace('\\', '/') + "/(\\d{8})$",
+                           path.path.replace('\\', '/'))
+        else:
+          match = re.match("^" + export_dir_base + "/(\\d{8})$",
+                           path.path)
         if not match:
           return None
         return path._replace(export_version=int(match.group(1)))
...
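The two regex-parser hunks (in the gc test above and in `Exporter`) share one idea: the export version is extracted by matching a regex built from a base directory against paths reported by the filesystem, so on Windows the backslash separators have to be normalized to `/` on both sides before matching. A standalone sketch of that pattern (plain Python; `extract_export_version` and the sample paths are illustrative, and `re.escape` is added here for robustness even though the TensorFlow code concatenates the directory directly):

```python
import os
import re


def extract_export_version(base_dir, path):
  # Normalize Windows separators so the anchored regex below matches
  # regardless of how the filesystem reports the path.
  if os.name == 'nt':
    base_dir = base_dir.replace('\\', '/')
    path = path.replace('\\', '/')
  match = re.match("^" + re.escape(base_dir) + "/(\\d+)$", path)
  return int(match.group(1)) if match else None


# Forward-slash paths match on any platform.
assert extract_export_version("/tmp/export", "/tmp/export/00000010") == 10
# Non-numeric final segments are rejected, like the "ignore" directory in the gc test.
assert extract_export_version("/tmp/export", "/tmp/export/ignore") is None
# On Windows, "C:\\tmp\\export\\00000010" against base "C:\\tmp\\export" would
# also match, because both strings are normalized first.
```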