Commit 90789091 authored by Martin Wicke

Merge pull request #2649 from martinwicke/branch_124012080

Branch 124012080
......@@ -92,6 +92,8 @@ filegroup(
"//tensorflow/contrib/quantization:all_files",
"//tensorflow/contrib/quantization/kernels:all_files",
"//tensorflow/contrib/quantization/tools:all_files",
"//tensorflow/contrib/session_bundle:all_files",
"//tensorflow/contrib/session_bundle/example:all_files",
"//tensorflow/contrib/skflow:all_files",
"//tensorflow/contrib/slim:all_files",
"//tensorflow/contrib/tensor_forest:all_files",
......
......@@ -82,9 +82,12 @@ string DebugString(const Tensor& x, const Tensor& y) {
CHECK_EQ(y.NumElements(), 2);
auto x_flat = x.flat<float>();
auto y_flat = y.flat<float>();
const float lambda = y_flat(0) / x_flat(0);
// Compute an estimate of the eigenvalue via
// (x' A x) / (x' x) = (x' y) / (x' x)
// and exploit the fact that x' x = 1 by assumption
Eigen::Tensor<float, 0, Eigen::RowMajor> lambda = (x_flat * y_flat).sum();
return strings::Printf("lambda = %8.6f x = [%8.6f %8.6f] y = [%8.6f %8.6f]",
lambda, x_flat(0), x_flat(1), y_flat(0), y_flat(1));
lambda(), x_flat(0), x_flat(1), y_flat(0), y_flat(1));
}
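A quick NumPy sketch of the new estimate (the matrix A is a made-up example; in the TF program, y is the product A x produced by the graph): since x'x = 1 by construction, the Rayleigh quotient (x' A x) / (x' x) reduces to x'y.

    import numpy as np

    A = np.array([[3., 1.], [1., 3.]], dtype=np.float32)  # hypothetical symmetric matrix
    x = np.array([1., 1.], dtype=np.float32)
    x /= np.linalg.norm(x)   # enforce the assumption x'x = 1
    y = A @ x                # one power-iteration step
    lam = float(x @ y)       # Rayleigh quotient; 4.0 here, the dominant eigenvalue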
void ConcurrentSteps(const Options* opts, int session_index) {
......@@ -106,7 +109,11 @@ void ConcurrentSteps(const Options* opts, int session_index) {
step_threads.Schedule([&session, opts, session_index, step]() {
// Randomly initialize the input.
Tensor x(DT_FLOAT, TensorShape({2, 1}));
x.flat<float>().setRandom();
auto x_flat = x.flat<float>();
x_flat.setRandom();
Eigen::Tensor<float, 0, Eigen::RowMajor> inv_norm =
x_flat.square().sum().sqrt().inverse();
x_flat = x_flat * inv_norm();
// Iterations.
std::vector<Tensor> outputs;
......
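For reference, the unit-norm initialization added above, in NumPy terms (a sketch, not the TF code):

    import numpy as np

    x = np.random.randn(2, 1).astype(np.float32)   # random init, as in the hunk
    inv_norm = 1.0 / np.sqrt((x ** 2).sum())       # inverse Euclidean norm
    x = x * inv_norm                               # now x'x = 1, as DebugString assumes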
......@@ -61,16 +61,16 @@ class Chi2Test(tf.test.TestCase):
df_v = np.array([1., 3, 5], dtype=np.float64)
expected_mean = stats.chi2.mean(df_v)
chi2 = tf.contrib.distributions.Chi2(df=df_v)
self.assertEqual(chi2.mean.get_shape(), (3,))
self.assertAllClose(chi2.mean.eval(), expected_mean)
self.assertEqual(chi2.mean().get_shape(), (3,))
self.assertAllClose(chi2.mean().eval(), expected_mean)
def testChi2Variance(self):
with tf.Session():
df_v = np.array([1., 3, 5], np.float64)
expected_variances = stats.chi2.var(df_v)
chi2 = tf.contrib.distributions.Chi2(df=df_v)
self.assertEqual(chi2.variance.get_shape(), (3,))
self.assertAllClose(chi2.variance.eval(), expected_variances)
self.assertEqual(chi2.variance().get_shape(), (3,))
self.assertAllClose(chi2.variance().eval(), expected_variances)
def testChi2Entropy(self):
with tf.Session():
......
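This is the pattern repeated throughout the PR: distribution statistics move from properties to methods, so every call site gains parentheses. A minimal sketch, assuming the tf.contrib.distributions API of this branch:

    import tensorflow as tf

    with tf.Session():
        chi2 = tf.contrib.distributions.Chi2(df=[1., 3., 5.])
        # old API (removed here): chi2.mean.eval()
        chi2.mean().eval()   # new API: mean/variance/std/mode are methods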
......@@ -170,7 +170,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
counts = np.zeros((3), dtype=np.float32)
counts[class_num] = 1
dist = tf.contrib.distributions.DirichletMultinomial(1., alpha)
mean = dist.mean.eval()
mean = dist.mean().eval()
pmf = dist.pmf(counts).eval()
self.assertAllClose(mean[class_num], pmf)
......@@ -192,8 +192,8 @@ class DirichletMultinomialTest(tf.test.TestCase):
dist1 = tf.contrib.distributions.DirichletMultinomial(1., alpha)
dist2 = tf.contrib.distributions.DirichletMultinomial(2., alpha)
mean1 = dist1.mean.eval()
mean2 = dist2.mean.eval()
mean1 = dist1.mean().eval()
mean2 = dist2.mean().eval()
self.assertAllClose(mean2[class_num], 2 * mean1[class_num])
self.assertTupleEqual((3,), mean1.shape)
......
......@@ -61,16 +61,16 @@ class ExponentialTest(tf.test.TestCase):
lam_v = np.array([1.0, 4.0, 2.5])
expected_mean = stats.expon.mean(scale=1 / lam_v)
exponential = tf.contrib.distributions.Exponential(lam=lam_v)
self.assertEqual(exponential.mean.get_shape(), (3,))
self.assertAllClose(exponential.mean.eval(), expected_mean)
self.assertEqual(exponential.mean().get_shape(), (3,))
self.assertAllClose(exponential.mean().eval(), expected_mean)
def testExponentialVariance(self):
with tf.Session():
lam_v = np.array([1.0, 4.0, 2.5])
expected_variance = stats.expon.var(scale=1 / lam_v)
exponential = tf.contrib.distributions.Exponential(lam=lam_v)
self.assertEqual(exponential.variance.get_shape(), (3,))
self.assertAllClose(exponential.variance.eval(), expected_variance)
self.assertEqual(exponential.variance().get_shape(), (3,))
self.assertAllClose(exponential.variance().eval(), expected_variance)
def testExponentialEntropy(self):
with tf.Session():
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for initializers."""
from __future__ import absolute_import
from __future__ import division
......@@ -26,18 +25,18 @@ import tensorflow as tf
class GammaTest(tf.test.TestCase):
def testGammaShape(self):
with tf.Session():
with self.test_session():
alpha = tf.constant([3.0] * 5)
beta = tf.constant(11.0)
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
self.assertEqual(gamma.batch_shape().eval(), (5,))
self.assertEqual(gamma.get_batch_shape(), tf.TensorShape([5]))
self.assertEqual(gamma.event_shape().eval(), 1)
self.assertAllEqual(gamma.event_shape().eval(), [])
self.assertEqual(gamma.get_event_shape(), tf.TensorShape([]))
def testGammaLogPDF(self):
with tf.Session():
with self.test_session():
batch_size = 6
alpha = tf.constant([2.0] * batch_size)
beta = tf.constant([3.0] * batch_size)
......@@ -55,7 +54,7 @@ class GammaTest(tf.test.TestCase):
self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))
def testGammaLogPDFMultidimensional(self):
with tf.Session():
with self.test_session():
batch_size = 6
alpha = tf.constant([[2.0, 4.0]] * batch_size)
beta = tf.constant([[3.0, 4.0]] * batch_size)
......@@ -75,7 +74,7 @@ class GammaTest(tf.test.TestCase):
self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
def testGammaLogPDFMultidimensionalBroadcasting(self):
with tf.Session():
with self.test_session():
batch_size = 6
alpha = tf.constant([[2.0, 4.0]] * batch_size)
beta = tf.constant(3.0)
......@@ -95,7 +94,7 @@ class GammaTest(tf.test.TestCase):
self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
def testGammaCDF(self):
with tf.Session():
with self.test_session():
batch_size = 6
alpha = tf.constant([2.0] * batch_size)
beta = tf.constant([3.0] * batch_size)
......@@ -111,25 +110,45 @@ class GammaTest(tf.test.TestCase):
self.assertAllClose(cdf.eval(), expected_cdf)
def testGammaMean(self):
with tf.Session():
with self.test_session():
alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v)
expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
self.assertEqual(gamma.mean.get_shape(), (3,))
self.assertAllClose(gamma.mean.eval(), expected_means)
self.assertEqual(gamma.mean().get_shape(), (3,))
self.assertAllClose(gamma.mean().eval(), expected_means)
def testGammaMode(self):
with self.test_session():
# Mode will not be defined for the first entry.
alpha_v = np.array([0.5, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v)
expected_modes = (alpha_v - 1) / beta_v
expected_modes[0] = np.nan
self.assertEqual(gamma.mode().get_shape(), (3,))
self.assertAllClose(gamma.mode().eval(), expected_modes)
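The closed form being tested: a Gamma's mode is (alpha - 1) / beta, defined only for alpha >= 1. Worked out in NumPy:

    import numpy as np

    alpha_v = np.array([0.5, 3.0, 2.5])
    beta_v = np.array([1.0, 4.0, 5.0])
    mode = (alpha_v - 1) / beta_v   # [-0.5, 0.5, 0.3]
    mode[alpha_v < 1] = np.nan      # undefined for alpha < 1
    # -> [nan, 0.5, 0.3], the expected_modes above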
def testGammaVariance(self):
with tf.Session():
with self.test_session():
alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v)
expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
self.assertEqual(gamma.variance.get_shape(), (3,))
self.assertAllClose(gamma.variance.eval(), expected_variances)
self.assertEqual(gamma.variance().get_shape(), (3,))
self.assertAllClose(gamma.variance().eval(), expected_variances)
def testGammaStd(self):
with self.test_session():
alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v)
expected_std = stats.gamma.std(alpha_v, scale=1 / beta_v)
self.assertEqual(gamma.std().get_shape(), (3,))
self.assertAllClose(gamma.std().eval(), expected_std)
def testGammaEntropy(self):
with tf.Session():
with self.test_session():
alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
......@@ -138,12 +157,12 @@ class GammaTest(tf.test.TestCase):
self.assertAllClose(gamma.entropy().eval(), expected_entropy)
def testGammaNonPositiveInitializationParamsRaises(self):
with tf.Session():
with self.test_session():
alpha_v = tf.constant(0.0, name='alpha')
beta_v = tf.constant(1.0, name='beta')
gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v)
with self.assertRaisesOpError('alpha'):
gamma.mean.eval()
gamma.mean().eval()
if __name__ == '__main__':
......
......@@ -21,6 +21,7 @@ from __future__ import print_function
import math
import numpy as np
from scipy import stats
import tensorflow as tf
......@@ -31,13 +32,9 @@ class NormalTest(tf.test.TestCase):
batch_size = 6
mu = tf.constant([3.0] * batch_size)
sigma = tf.constant([math.sqrt(10.0)] * batch_size)
mu_v = 3.0
sigma_v = np.sqrt(10.0)
x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
expected_log_pdf = np.log(
1 / np.sqrt(2 * np.pi) / sigma_v
* np.exp(-1.0 / (2 * sigma_v**2) * (x - mu_v)**2))
expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x)
log_pdf = normal.log_pdf(x)
self.assertAllClose(expected_log_pdf, log_pdf.eval())
......@@ -58,13 +55,9 @@ class NormalTest(tf.test.TestCase):
batch_size = 6
mu = tf.constant([[3.0, -3.0]] * batch_size)
sigma = tf.constant([[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size)
mu_v = np.array([3.0, -3.0])
sigma_v = np.array([np.sqrt(10.0), np.sqrt(15.0)])
x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
expected_log_pdf = np.log(
1 / np.sqrt(2 * np.pi) / sigma_v
* np.exp(-1.0 / (2 * sigma_v**2) * (x - mu_v)**2))
expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x)
log_pdf = normal.log_pdf(x)
log_pdf_values = log_pdf.eval()
......@@ -89,15 +82,10 @@ class NormalTest(tf.test.TestCase):
batch_size = 6
mu = tf.constant([3.0] * batch_size)
sigma = tf.constant([math.sqrt(10.0)] * batch_size)
mu_v = 3.0
sigma_v = np.sqrt(10.0)
x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
erf_fn = np.vectorize(math.erf)
# From Wikipedia
expected_cdf = 0.5 * (1.0 + erf_fn((x - mu_v)/(sigma_v*np.sqrt(2))))
expected_cdf = stats.norm(mu.eval(), sigma.eval()).cdf(x)
cdf = normal.cdf(x)
self.assertAllClose(expected_cdf, cdf.eval())
......@@ -106,21 +94,74 @@ class NormalTest(tf.test.TestCase):
self.assertAllEqual(normal.get_batch_shape(), cdf.get_shape())
self.assertAllEqual(normal.get_batch_shape(), cdf.eval().shape)
def testNormalEntropyWithScalarInputs(self):
# Scipy.stats.norm cannot deal with the shapes in the other test.
with self.test_session():
mu_v = 2.34
sigma_v = 4.56
normal = tf.contrib.distributions.Normal(mu=mu_v, sigma=sigma_v)
# scipy.stats.norm cannot deal with these shapes.
expected_entropy = stats.norm(mu_v, sigma_v).entropy()
entropy = normal.entropy()
self.assertAllClose(expected_entropy, entropy.eval())
self.assertAllEqual(normal.batch_shape().eval(), entropy.get_shape())
self.assertAllEqual(normal.batch_shape().eval(), entropy.eval().shape)
self.assertAllEqual(normal.get_batch_shape(), entropy.get_shape())
self.assertAllEqual(normal.get_batch_shape(), entropy.eval().shape)
def testNormalEntropy(self):
with self.test_session():
mu_v = np.array([1.0, 1.0, 1.0])
sigma_v = np.array([[1.0, 2.0, 3.0]]).T
normal = tf.contrib.distributions.Normal(mu=mu_v, sigma=sigma_v)
# scipy.stats.norm cannot deal with these shapes.
sigma_broadcast = mu_v * sigma_v
expected_entropy = 0.5 * np.log(2*np.pi*np.exp(1)*sigma_broadcast**2)
entropy = normal.entropy()
self.assertAllClose(expected_entropy, entropy.eval())
np.testing.assert_allclose(expected_entropy, entropy.eval())
self.assertAllEqual(normal.batch_shape().eval(), entropy.get_shape())
self.assertAllEqual(normal.batch_shape().eval(), entropy.eval().shape)
self.assertAllEqual(normal.get_batch_shape(), entropy.get_shape())
self.assertAllEqual(normal.get_batch_shape(), entropy.eval().shape)
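The closed form used above: Normal entropy is 0.5 * log(2 * pi * e * sigma**2) and does not depend on mu. A quick scipy cross-check:

    import numpy as np
    from scipy import stats

    mu, sigma = 2.34, 4.56
    h = 0.5 * np.log(2 * np.pi * np.e * sigma ** 2)
    assert np.isclose(h, stats.norm(mu, sigma).entropy())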
def testNormalMeanAndMode(self):
with self.test_session():
# Mu will be broadcast to [7, 7, 7].
mu = [7.]
sigma = [11., 12., 13.]
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
self.assertAllEqual((3,), normal.mean().get_shape())
self.assertAllEqual([7., 7, 7], normal.mean().eval())
self.assertAllEqual((3,), normal.mode().get_shape())
self.assertAllEqual([7., 7, 7], normal.mode().eval())
def testNormalVariance(self):
with self.test_session():
# sigma will be broadcast to [7, 7, 7]
mu = [1., 2., 3.]
sigma = [7.]
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
self.assertAllEqual((3,), normal.variance().get_shape())
self.assertAllEqual([49., 49, 49], normal.variance().eval())
def testNormalStandardDeviation(self):
with self.test_session():
# sigma will be broadcast to [7, 7, 7]
mu = [1., 2., 3.]
sigma = [7.]
normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
self.assertAllEqual((3,), normal.std().get_shape())
self.assertAllEqual([7., 7, 7], normal.std().eval())
def testNormalSample(self):
with self.test_session():
mu = tf.constant(3.0)
......@@ -183,9 +224,8 @@ class NormalTest(tf.test.TestCase):
mu=[1.],
sigma=[-5.],
name='G')
with self.assertRaisesOpError(
r'should contain only positive values'):
normal.mean.eval()
with self.assertRaisesOpError('Condition x > 0 did not hold'):
normal.mean().eval()
def testNormalShape(self):
with self.test_session():
......@@ -195,7 +235,7 @@ class NormalTest(tf.test.TestCase):
self.assertEqual(normal.batch_shape().eval(), [5])
self.assertEqual(normal.get_batch_shape(), tf.TensorShape([5]))
self.assertEqual(normal.event_shape().eval(), 1)
self.assertAllEqual(normal.event_shape().eval(), [])
self.assertEqual(normal.get_event_shape(), tf.TensorShape([]))
def testNormalShapeWithPlaceholders(self):
......@@ -207,11 +247,12 @@ class NormalTest(tf.test.TestCase):
# get_batch_shape should return an "<unknown>" tensor.
self.assertEqual(normal.get_batch_shape(), tf.TensorShape(None))
self.assertEqual(normal.get_event_shape(), ())
self.assertEqual(normal.event_shape().eval(), 1)
self.assertAllEqual(normal.event_shape().eval(), [])
self.assertAllEqual(
sess.run(normal.batch_shape(),
feed_dict={mu: 5.0, sigma: [1.0, 2.0]}),
[2])
if __name__ == '__main__':
tf.test.main()
......@@ -166,8 +166,8 @@ class StudentTTest(tf.test.TestCase):
def testBroadcastingParams(self):
def _check(student):
self.assertEqual(student.mean.get_shape(), (3,))
self.assertEqual(student.variance.get_shape(), (3,))
self.assertEqual(student.mean().get_shape(), (3,))
self.assertEqual(student.variance().get_shape(), (3,))
self.assertEqual(student.entropy().get_shape(), (3,))
self.assertEqual(student.log_pdf(2.).get_shape(), (3,))
self.assertEqual(student.pdf(2.).get_shape(), (3,))
......@@ -228,22 +228,60 @@ class StudentTTest(tf.test.TestCase):
_check2d_rows(tf.contrib.distributions.StudentT(
df=7., mu=3., sigma=[[2.], [3.], [4.]]))
def testMeanVar(self):
def testMean(self):
with tf.Session():
mu = [-2, 0., 1., 3.3, 4.4]
student = tf.contrib.distributions.StudentT(
df=[1., 2., 3., 5., 7.],
mu=np.exp(1, dtype=np.float32),
df=[0.5, 1., 3., 5., 7.],
mu=mu,
sigma=[5., 4., 3., 2., 1.])
# Test broadcast of mu across shape of df/sigma
mean = student.mean.eval()
self.assertAllClose([np.exp(1, dtype=np.float32)] * 5, mean)
var = student.variance.eval()
# loc does not affect variance, so we use 0.
self.assertAllClose([stats.t.var(1., loc=0., scale=5.),
stats.t.var(2., loc=0., scale=4.),
stats.t.var(3., loc=0., scale=3.),
stats.t.var(5., loc=0., scale=2.),
stats.t.var(7., loc=0., scale=1.)], var)
mean = student.mean().eval()
self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)
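What the new assertion encodes: a Student's t mean equals mu only for df > 1 and is undefined otherwise, hence the leading NaNs:

    import numpy as np

    df = np.array([0.5, 1., 3., 5., 7.])
    mu = np.array([-2., 0., 1., 3.3, 4.4])
    mean = np.where(df > 1, mu, np.nan)   # [nan, nan, 1.0, 3.3, 4.4]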
def testVariance(self):
with tf.Session():
df = [0.5, 1., 3., 5., 7.]
mu = [-2, 0., 1., 3.3, 4.4]
sigma = [5., 4., 3., 2., 1.]
student = tf.contrib.distributions.StudentT(df=df, mu=mu, sigma=sigma)
# Test broadcast of mu across shape of df/sigma
var = student.variance().eval()
# scipy uses inf rather than nan here. Assert we use NaN, then replace
# with infinity to compare to scipy.
self.assertFalse(np.isinf(var).any())
var[np.isnan(var)] = np.inf
expected_var = [
stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)]
self.assertAllClose(expected_var, var)
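The reference values above come from the closed form var = sigma**2 * df / (df - 2), valid for df > 2; scipy reports inf for df <= 2, where the op in this PR returns NaN:

    import numpy as np
    from scipy import stats

    df, sigma = 5.0, 2.0
    assert np.isclose(stats.t.var(df, scale=sigma), sigma ** 2 * df / (df - 2))
    print(stats.t.var(1.0, scale=4.0))   # inf, where variance() here gives nan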
def testStd(self):
with tf.Session():
df = [0.5, 1., 3., 5., 7.]
mu = [-2, 0., 1., 3.3, 4.4]
sigma = [5., 4., 3., 2., 1.]
student = tf.contrib.distributions.StudentT(df=df, mu=mu, sigma=sigma)
# Test broadcast of mu across shape of df/sigma
std = student.std().eval()
# scipy uses inf rather than nan here. Assert we use NaN, then replace
# with infinity to compare to scipy.
self.assertFalse(np.isinf(std).any())
std[np.isnan(std)] = np.inf
expected_std = [
stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)]
self.assertAllClose(expected_std, std)
def testMode(self):
with tf.Session():
student = tf.contrib.distributions.StudentT(
df=[0.5, 1., 3],
mu=[-1, 0., 1],
sigma=[5., 4., 3.])
# Test broadcast of mu across shape of df/sigma
mode = student.mode().eval()
self.assertAllClose([-1., 0, 1], mode)
def testPdfOfSample(self):
with tf.Session() as sess:
......@@ -251,10 +289,10 @@ class StudentTTest(tf.test.TestCase):
num = 20000
samples = student.sample(num, seed=137)
pdfs = student.pdf(samples)
mean = student.mean
mean_pdf = student.pdf(student.mean)
mean = student.mean()
mean_pdf = student.pdf(student.mean())
sample_vals, pdf_vals, mean_val, mean_pdf_val = sess.run(
[samples, pdfs, student.mean, mean_pdf])
[samples, pdfs, student.mean(), mean_pdf])
self.assertEqual(samples.get_shape(), (num,))
self.assertEqual(pdfs.get_shape(), (num,))
self.assertEqual(mean.get_shape(), ())
......@@ -305,7 +343,7 @@ class StudentTTest(tf.test.TestCase):
sigma=1.,
name='S')
with self.assertRaisesOpError(r'Condition x > 0 did not hold'):
student.mean.eval()
student.mean().eval()
def testNegativeScaleFails(self):
with tf.Session():
......@@ -314,7 +352,7 @@ class StudentTTest(tf.test.TestCase):
sigma=[[3.], [-2.]],
name='S')
with self.assertRaisesOpError(r'Condition x > 0 did not hold'):
student.mean.eval()
student.mean().eval()
if __name__ == '__main__':
......
......@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
from scipy import stats
import tensorflow as tf
......@@ -31,7 +32,7 @@ class UniformTest(tf.test.TestCase):
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
self.assertAllClose(a, uniform.a.eval())
self.assertAllClose(b, uniform.b.eval())
self.assertAllClose(b - a, uniform.range.eval())
self.assertAllClose(b - a, uniform.range().eval())
def testUniformPDF(self):
with self.test_session():
......@@ -66,7 +67,7 @@ class UniformTest(tf.test.TestCase):
self.assertEqual(uniform.batch_shape().eval(), (5,))
self.assertEqual(uniform.get_batch_shape(), tf.TensorShape([5]))
self.assertEqual(uniform.event_shape().eval(), 1)
self.assertAllEqual(uniform.event_shape().eval(), [])
self.assertEqual(uniform.get_event_shape(), tf.TensorShape([]))
def testUniformPDFWithScalarEndpoint(self):
......@@ -172,13 +173,29 @@ class UniformTest(tf.test.TestCase):
self.assertAllClose(sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2,
atol=1e-2)
def testUniformMeanAndVariance(self):
def testUniformMean(self):
with self.test_session():
a = 10.0
b = 100.0
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
self.assertAllClose(uniform.variance.eval(), (b - a)**2 / 12)
self.assertAllClose(uniform.mean.eval(), (b + a) / 2)
s_uniform = stats.uniform(loc=a, scale=b-a)
self.assertAllClose(uniform.mean().eval(), s_uniform.mean())
def testUniformVariance(self):
with self.test_session():
a = 10.0
b = 100.0
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
s_uniform = stats.uniform(loc=a, scale=b-a)
self.assertAllClose(uniform.variance().eval(), s_uniform.var())
def testUniformStd(self):
with self.test_session():
a = 10.0
b = 100.0
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
s_uniform = stats.uniform(loc=a, scale=b-a)
self.assertAllClose(uniform.std().eval(), s_uniform.std())
def testUniformNans(self):
with self.test_session():
......
......@@ -19,7 +19,7 @@ from __future__ import print_function
# pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops.distribution import DiscreteDistribution
from tensorflow.contrib.distributions.python.ops import distribution # pylint: disable=line-too-long
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
......@@ -73,7 +73,7 @@ def _log_combinations(n, counts, name='log_combinations'):
return total_permutations - redundant_permutations
class DirichletMultinomial(DiscreteDistribution):
class DirichletMultinomial(distribution.DiscreteDistribution):
"""DirichletMultinomial mixture distribution.
This distribution is parameterized by a vector `alpha` of concentration
......@@ -195,19 +195,13 @@ class DirichletMultinomial(DiscreteDistribution):
self._allow_arbitrary_counts = allow_arbitrary_counts
alpha_sum = math_ops.reduce_sum(self._alpha,
reduction_indices=[-1],
keep_dims=False)
self._alpha_sum = math_ops.reduce_sum(
self._alpha, reduction_indices=[-1], keep_dims=False)
mean = self._alpha / array_ops.expand_dims(alpha_sum, -1)
self._mean = array_ops.expand_dims(n, -1) * mean
self._get_batch_shape = self._alpha_sum.get_shape()
self._batch_shape = array_ops.shape(alpha_sum)
self._get_batch_shape = alpha_sum.get_shape()
self._event_shape = array_ops.reverse(
array_ops.shape(self._mean), [True])[0]
self._get_event_shape = self._mean.get_shape().with_rank_at_least(1)[-1:]
# event shape depends only on alpha, not "n".
self._get_event_shape = self._alpha.get_shape().with_rank_at_least(1)[-1:]
@property
def n(self):
......@@ -227,12 +221,17 @@ class DirichletMultinomial(DiscreteDistribution):
@property
def dtype(self):
"""dtype of samples from this distribution."""
return self._mean.dtype
return self._alpha.dtype
@property
def mean(self):
def mean(self, name='mean'):
"""Class means for every batch member."""
return self._mean
alpha = self._alpha
alpha_sum = self._alpha_sum
n = self._n
with ops.name_scope(self.name):
with ops.op_scope([alpha, alpha_sum, n], name):
mean_no_n = alpha / array_ops.expand_dims(alpha_sum, -1)
return array_ops.expand_dims(n, -1) * mean_no_n
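The mean is now built on demand rather than cached in the constructor; numerically it is n * alpha / sum(alpha), e.g. in NumPy:

    import numpy as np

    alpha = np.array([1.0, 2.0, 3.0])   # hypothetical concentrations
    n = 2.0
    mean = n * alpha / alpha.sum()      # [1/3, 2/3, 1.0]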
def batch_shape(self, name='batch_shape'):
"""Batch dimensions of this instance as a 1-D int32 `Tensor`.
......@@ -247,8 +246,8 @@ class DirichletMultinomial(DiscreteDistribution):
`Tensor` `batch_shape`
"""
with ops.name_scope(self.name):
with ops.op_scope([], name):
return self._batch_shape
with ops.op_scope([self._alpha_sum], name):
return array_ops.shape(self._alpha_sum)
def get_batch_shape(self):
"""`TensorShape` available at graph construction time.
......@@ -270,8 +269,8 @@ class DirichletMultinomial(DiscreteDistribution):
`Tensor` `event_shape`
"""
with ops.name_scope(self.name):
with ops.op_scope([], name):
return self._event_shape
with ops.op_scope([self._alpha], name):
return array_ops.reverse(array_ops.shape(self._alpha), [True])[0]
def get_event_shape(self):
"""`TensorShape` available at graph construction time.
......@@ -283,11 +282,11 @@ class DirichletMultinomial(DiscreteDistribution):
"""
return self._get_event_shape
def cdf(self, x):
def cdf(self, x, name='cdf'):
raise NotImplementedError(
'DirichletMultinomial does not have a well-defined cdf.')
def log_cdf(self, x):
def log_cdf(self, x, name='log_cdf'):
raise NotImplementedError(
'DirichletMultinomial does not have a well-defined cdf.')
......@@ -356,9 +355,7 @@ class DirichletMultinomial(DiscreteDistribution):
Returns:
Probabilities for each record, shape `[N1,...,Nn]`.
"""
with ops.name_scope(self.name):
with ops.op_scope([], name):
return super(DirichletMultinomial, self).pmf(counts, name=name)
return super(DirichletMultinomial, self).pmf(counts, name=name)
def _check_counts(self, counts):
"""Check counts for proper shape, values, then return tensor version."""
......
......@@ -83,7 +83,7 @@ class BaseDistribution(object):
# `event_shape` is `TensorShape([])`.
event_shape = u.get_event_shape()
# `event_shape_t` is a `Tensor` which will evaluate to a scalar 1.
# `event_shape_t` is a `Tensor` which will evaluate to [].
event_shape_t = u.event_shape
# Sampling returns a sample per distribution. `samples` has shape
......@@ -112,15 +112,17 @@ class BaseDistribution(object):
@abc.abstractproperty
def name(self):
"""Name to prepend to all ops."""
# return self._name.
pass
@abc.abstractproperty
def dtype(self):
"""dtype of samples from this distribution."""
# return self._dtype
pass
@abc.abstractmethod
def event_shape(self, name=None):
def event_shape(self, name="event_shape"):
"""Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
Args:
......@@ -129,6 +131,10 @@ class BaseDistribution(object):
Returns:
`Tensor` `event_shape`
"""
# For scalar distributions, constant([], int32)
# with ops.name_scope(self.name):
# with ops.op_scope([tensor_arguments], name):
# Your code here
pass
@abc.abstractmethod
......@@ -137,10 +143,11 @@ class BaseDistribution(object):
Same meaning as `event_shape`. May be only partially defined.
"""
# return self._event_shape
pass
@abc.abstractmethod
def batch_shape(self, name=None):
def batch_shape(self, name="batch_shape"):
"""Batch dimensions of this instance as a 1-D int32 `Tensor`.
The product of the dimensions of the `batch_shape` is the number of
......@@ -152,6 +159,9 @@ class BaseDistribution(object):
Returns:
`Tensor` `batch_shape`
"""
# with ops.name_scope(self.name):
# with ops.op_scope([tensor_arguments], name):
# Your code here
pass
@abc.abstractmethod
......@@ -162,7 +172,7 @@ class BaseDistribution(object):
"""
pass
def sample(self, n, seed=None, name=None):
def sample(self, n, seed=None, name="sample"):
"""Generate `n` samples.
Args:
......@@ -178,23 +188,39 @@ class BaseDistribution(object):
def cdf(self, value, name="cdf"):
"""Cumulative distribution function."""
value = ops.convert_to_tensor(value)
with ops.op_scope([value], self.name):
with ops.name_scope(name):
with ops.name_scope(self.name):
with ops.op_scope([value], name):
value = ops.convert_to_tensor(value)
return math_ops.exp(self.log_cdf(value))
def log_cdf(self, value, name="log_cdf"):
"""Log CDF."""
raise NotImplementedError("log_cdf is not implemented")
def entropy(self, name=None):
def entropy(self, name="entropy"):
"""Entropy of the distribution in nats."""
raise NotImplementedError("entropy not implemented")
@property
def mean(self):
def mean(self, name="mean"):
"""Mean of the distribution."""
# Set to np.nan if parameters mean it is undefined/infinite.
raise NotImplementedError("mean not implemented")
def mode(self, name="mode"):
"""Mode of the distribution."""
# Set to np.nan if parameters mean it is undefined/infinite.
raise NotImplementedError("mode not implemented")
def std(self, name="std"):
"""Standard deviation of the distribution."""
# Set to np.nan if parameters mean it is undefined/infinite.
raise NotImplementedError("std not implemented")
def variance(self, name="variance"):
"""Variance of the distribution."""
# Set to np.nan if parameters mean it is undefined/infinite.
raise NotImplementedError("variance not implemented")
class ContinuousDistribution(BaseDistribution):
"""Base class for continuous probability distributions.
......@@ -223,19 +249,25 @@ class ContinuousDistribution(BaseDistribution):
@abc.abstractmethod
def pdf(self, value, name="pdf"):
"""Probability density function."""
value = ops.convert_to_tensor(value)
with ops.op_scope([value], self.name):
with ops.name_scope(name):
with ops.name_scope(self.name):
with ops.op_scope([value], name):
value = ops.convert_to_tensor(value)
return math_ops.exp(self.log_pdf(value))
@abc.abstractmethod
def log_pdf(self, value, name="log_pdf"):
"""Log of the probability density function."""
value = ops.convert_to_tensor(value)
with ops.op_scope([value], self.name):
with ops.name_scope(name):
with ops.name_scope(self.name):
with ops.op_scope([value], name):
value = ops.convert_to_tensor(value)
return math_ops.log(self.pdf(value))
def log_likelihood(self, value, name="log_likelihood"):
"""Log likelihood of this distribution (same as log_pdf)."""
with ops.name_scope(self.name):
with ops.op_scope([value], name):
return self.log_pdf(value)
class DiscreteDistribution(BaseDistribution):
"""Base class for discrete probability distributions.
......@@ -253,15 +285,21 @@ class DiscreteDistribution(BaseDistribution):
@abc.abstractmethod
def pmf(self, value, name="pmf"):
"""Probability mass function."""
value = ops.convert_to_tensor(value)
with ops.op_scope([value], self.name):
with ops.name_scope(name):
with ops.name_scope(self.name):
with ops.op_scope([value], name):
value = ops.convert_to_tensor(value)
return math_ops.exp(self.log_pmf(value))
@abc.abstractmethod
def log_pmf(self, value, name="log_pmf"):
"""Log of the probability mass function."""
value = ops.convert_to_tensor(value)
with ops.op_scope([value], self.name):
with ops.name_scope(name):
with ops.name_scope(self.name):
with ops.op_scope([value], name):
value = ops.convert_to_tensor(value)
return math_ops.log(self.pmf(value))
def log_likelihood(self, value, name="log_likelihood"):
"""Log likelihood of this distribution (same as log_pmf)."""
with ops.name_scope(self.name):
with ops.op_scope([value], name):
return self.log_pmf(value)
......@@ -18,8 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops.distribution import ContinuousDistribution # pylint: disable=line-too-long
import numpy as np
from tensorflow.contrib.distributions.python.ops import distribution # pylint: disable=line-too-long
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util # pylint: disable=line-too-long
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
......@@ -29,7 +32,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
class Gamma(ContinuousDistribution):
class Gamma(distribution.ContinuousDistribution):
"""The `Gamma` distribution with parameter alpha and beta.
The parameters are the shape and inverse scale parameters alpha, beta.
......@@ -78,14 +81,9 @@ class Gamma(ContinuousDistribution):
beta = array_ops.identity(beta, name="beta")
contrib_tensor_util.assert_same_float_dtype((alpha, beta))
self._broadcast_tensor = alpha + beta
with ops.name_scope("mean"):
self._mean = alpha / beta
with ops.name_scope("variance"):
self._variance = alpha / math_ops.square(beta)
self._get_batch_shape = self._mean.get_shape()
self._get_batch_shape = self._broadcast_tensor.get_shape()
self._get_event_shape = tensor_shape.TensorShape([])
self._alpha = alpha
......@@ -125,7 +123,8 @@ class Gamma(ContinuousDistribution):
`Tensor` `batch_shape`
"""
with ops.name_scope(self.name):
return array_ops.shape(self._mean, name=name)
with ops.op_scope([self._broadcast_tensor], name):
return array_ops.shape(self._broadcast_tensor)
def get_batch_shape(self):
"""`TensorShape` available at graph construction time.
......@@ -147,7 +146,8 @@ class Gamma(ContinuousDistribution):
`Tensor` `event_shape`
"""
with ops.name_scope(self.name):
return constant_op.constant(1, name=name)
with ops.op_scope([], name):
return constant_op.constant([], dtype=dtypes.int32)
def get_event_shape(self):
"""`TensorShape` available at graph construction time.
......@@ -159,15 +159,34 @@ class Gamma(ContinuousDistribution):
"""
return self._get_event_shape
@property
def mean(self):
def mean(self, name="mean"):
"""Mean of each batch member."""
return self._mean
with ops.name_scope(self.name):
with ops.op_scope([self._alpha, self._beta], name):
return self._alpha / self._beta
@property
def variance(self):
def mode(self, name="mode"):
"""Mode of each batch member. Defined only if alpha >= 1."""
alpha = self._alpha
beta = self._beta
with ops.name_scope(self.name):
with ops.op_scope([alpha, beta], name):
alpha_ge_1 = alpha >= 1.0
mode_if_defined = (alpha - 1.0) / beta
nan = np.nan * self._ones()
return math_ops.select(alpha_ge_1, mode_if_defined, nan)
def variance(self, name="variance"):
"""Variance of each batch member."""
return self._variance
with ops.name_scope(self.name):
with ops.op_scope([self._alpha, self._beta], name):
return self._alpha / math_ops.square(self._beta)
def std(self, name="std"):
"""Standard deviation of this distribution."""
with ops.name_scope(self.name):
with ops.op_scope([self._alpha, self._beta], name):
return math_ops.sqrt(self._alpha) / self._beta
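Closed forms behind the three methods above, cross-checked against scipy (beta is a rate, so scipy's scale is 1/beta):

    import numpy as np
    from scipy import stats

    alpha, beta = 3.0, 4.0
    assert np.isclose(stats.gamma.mean(alpha, scale=1 / beta), alpha / beta)
    assert np.isclose(stats.gamma.var(alpha, scale=1 / beta), alpha / beta ** 2)
    assert np.isclose(stats.gamma.std(alpha, scale=1 / beta), np.sqrt(alpha) / beta)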
def log_pdf(self, x, name="log_pdf"):
"""Log pdf of observations in `x` under these Gamma distribution(s).
......@@ -182,8 +201,8 @@ class Gamma(ContinuousDistribution):
Raises:
TypeError: if `x` and `alpha` are different dtypes.
"""
with ops.op_scope([self._alpha, self._beta, x], self.name):
with ops.name_scope(name):
with ops.name_scope(self.name):
with ops.op_scope([self._alpha, self._beta, x], name):
alpha = self._alpha
beta = self._beta
x = ops.convert_to_tensor(x)
......@@ -208,8 +227,9 @@ class Gamma(ContinuousDistribution):
Raises:
TypeError: if `x` and `alpha` are different dtypes.
"""
with ops.name_scope(name):
return math_ops.exp(self.log_pdf(x, name))
with ops.name_scope(self.name):
with ops.op_scope([], name):
return math_ops.exp(self.log_pdf(x))
def log_cdf(self, x, name="log_cdf"):
"""Log CDF of observations `x` under these Gamma distribution(s).
......@@ -221,8 +241,8 @@ class Gamma(ContinuousDistribution):
Returns:
log_cdf: tensor of dtype `dtype`, the log-CDFs of `x`.
"""
with ops.op_scope([self._alpha, self._beta, x], self.name):
with ops.name_scope(name):
with ops.name_scope(self.name):
with ops.op_scope([self._alpha, self._beta, x], name):
x = ops.convert_to_tensor(x)
x = control_flow_ops.with_dependencies(
[check_ops.assert_positive(x)], x)
......@@ -242,8 +262,8 @@ class Gamma(ContinuousDistribution):
Returns:
cdf: tensor of dtype `dtype`, the CDFs of `x`.
"""
with ops.op_scope([self._alpha, self._beta, x], self.name):
with ops.name_scope(name):
with ops.name_scope(self.name):
with ops.op_scope([self._alpha, self._beta, x], name):
return math_ops.igamma(self._alpha, self._beta * x)
def entropy(self, name="entropy"):
......@@ -264,8 +284,8 @@ class Gamma(ContinuousDistribution):
Returns:
entropy: tensor of dtype `dtype`, the entropy.
"""
with ops.op_scope([self.alpha, self._beta], self.name):
with ops.name_scope(name):
with ops.name_scope(self.name):
with ops.op_scope([self.alpha, self._beta], name):
alpha = self._alpha
beta = self._beta
return (alpha - math_ops.log(beta) + math_ops.lgamma(alpha) +
......@@ -274,3 +294,6 @@ class Gamma(ContinuousDistribution):
@property
def is_reparameterized(self):
return False
def _ones(self):
return array_ops.ones_like(self._alpha + self._beta, dtype=self.dtype)
......@@ -20,26 +20,20 @@ from __future__ import print_function
import math
from tensorflow.contrib.distributions.python.ops.distribution import ContinuousDistribution # pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops import distribution # pylint: disable=line-too-long
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util # pylint: disable=line-too-long
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
# TODO(ebrevdo): Use asserts contrib module when ready
def _assert_all_positive(x):
return logging_ops.Assert(
math_ops.reduce_all(x > 0),
["Tensor %s should contain only positive values: " % x.name, x])
class Normal(ContinuousDistribution):
class Normal(distribution.ContinuousDistribution):
"""The scalar Normal distribution with mean and stddev parameters mu, sigma.
#### Mathematical details
......@@ -103,7 +97,7 @@ class Normal(ContinuousDistribution):
with ops.op_scope([mu, sigma], name):
mu = ops.convert_to_tensor(mu)
sigma = ops.convert_to_tensor(sigma)
with ops.control_dependencies([_assert_all_positive(sigma)]):
with ops.control_dependencies([check_ops.assert_positive(sigma)]):
self._name = name
self._mu = array_ops.identity(mu, name="mu")
self._sigma = array_ops.identity(sigma, name="sigma")
......@@ -132,8 +126,9 @@ class Normal(ContinuousDistribution):
Returns:
`Tensor` `batch_shape`
"""
with ops.name_scope(self._name):
return array_ops.shape(self._ones(), name=name)
with ops.name_scope(self.name):
with ops.op_scope([], name):
return array_ops.shape(self._ones())
def get_batch_shape(self):
"""`TensorShape` available at graph construction time.
......@@ -154,8 +149,9 @@ class Normal(ContinuousDistribution):
Returns:
`Tensor` `event_shape`
"""
with ops.name_scope(self._name):
return constant_op.constant(1, name=name)
with ops.name_scope(self.name):
with ops.op_scope([], name):
return constant_op.constant([], dtype=dtypes.int32)
def get_event_shape(self):
"""`TensorShape` available at graph construction time.
......@@ -169,17 +165,37 @@ class Normal(ContinuousDistribution):
@property
def mu(self):
"""Distribution parameter for the mean."""
return self._mu
@property
def sigma(self):
"""Distribution parameter for standard deviation."""
return self._sigma
@property
def mean(self):
return self._mu * array_ops.ones_like(self._sigma)
def log_pdf(self, x, name=None):
def mean(self, name="mean"):
"""Mean of this distribution."""
with ops.name_scope(self.name):
with ops.op_scope([self._sigma, self._mu], name):
return self._mu * array_ops.ones_like(self._sigma)
def mode(self, name="mode"):
"""Mode of this distribution."""
return self.mean(name="mode")
def std(self, name="std"):
"""Standard deviation of this distribution."""
with ops.name_scope(self.name):
with ops.op_scope([self._sigma, self._mu], name):
return self._sigma * array_ops.ones_like(self._mu)
def variance(self, name="variance"):
"""Variance of this distribution."""
with ops.name_scope(self.name):
with ops.op_scope([], name):
return math_ops.square(self.std())
def log_pdf(self, x, name="log_pdf"):
"""Log pdf of observations in `x` under these Normal distribution(s).
Args:
......@@ -189,16 +205,17 @@ class Normal(ContinuousDistribution):
Returns:
log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
"""
with ops.op_scope([self._mu, self._sigma, x], name, "NormalLogPdf"):
x = ops.convert_to_tensor(x)
if x.dtype != self.dtype:
raise TypeError("Input x dtype does not match dtype: %s vs. %s"
% (x.dtype, self.dtype))
log_2_pi = constant_op.constant(math.log(2 * math.pi), dtype=self.dtype)
return (-0.5*log_2_pi - math_ops.log(self._sigma)
-0.5*math_ops.square((x - self._mu) / self._sigma))
def cdf(self, x, name=None):
with ops.name_scope(self.name):
with ops.op_scope([self._mu, self._sigma, x], name):
x = ops.convert_to_tensor(x)
if x.dtype != self.dtype:
raise TypeError("Input x dtype does not match dtype: %s vs. %s"
% (x.dtype, self.dtype))
log_2_pi = constant_op.constant(math.log(2 * math.pi), dtype=self.dtype)
return (-0.5*log_2_pi - math_ops.log(self._sigma)
-0.5*math_ops.square((x - self._mu) / self._sigma))
def cdf(self, x, name="cdf"):
"""CDF of observations in `x` under these Normal distribution(s).
Args:
......@@ -208,15 +225,16 @@ class Normal(ContinuousDistribution):
Returns:
cdf: tensor of dtype `dtype`, the CDFs of `x`.
"""
with ops.op_scope([self._mu, self._sigma, x], name, "NormalCdf"):
x = ops.convert_to_tensor(x)
if x.dtype != self.dtype:
raise TypeError("Input x dtype does not match dtype: %s vs. %s"
% (x.dtype, self.dtype))
return (0.5 + 0.5*math_ops.erf(
1.0/(math.sqrt(2.0) * self._sigma)*(x - self._mu)))
def log_cdf(self, x, name=None):
with ops.name_scope(self.name):
with ops.op_scope([self._mu, self._sigma, x], name):
x = ops.convert_to_tensor(x)
if x.dtype != self.dtype:
raise TypeError("Input x dtype does not match dtype: %s vs. %s"
% (x.dtype, self.dtype))
return (0.5 + 0.5*math_ops.erf(
1.0/(math.sqrt(2.0) * self._sigma)*(x - self._mu)))
def log_cdf(self, x, name="log_cdf"):
"""Log CDF of observations `x` under these Normal distribution(s).
Args:
......@@ -226,8 +244,9 @@ class Normal(ContinuousDistribution):
Returns:
log_cdf: tensor of dtype `dtype`, the log-CDFs of `x`.
"""
with ops.op_scope([self._mu, self._sigma, x], name, "NormalLogCdf"):
return math_ops.log(self.cdf(x))
with ops.name_scope(self.name):
with ops.op_scope([self._mu, self._sigma, x], name):
return math_ops.log(self.cdf(x))
def pdf(self, x, name="pdf"):
"""The PDF of observations in `x` under these Normal distribution(s).
......@@ -241,7 +260,7 @@ class Normal(ContinuousDistribution):
"""
return super(Normal, self).pdf(x, name=name)
def entropy(self, name=None):
def entropy(self, name="entropy"):
"""The entropy of Normal distribution(s).
Args:
......@@ -250,14 +269,15 @@ class Normal(ContinuousDistribution):
Returns:
entropy: tensor of dtype `dtype`, the entropy.
"""
with ops.op_scope([self._mu, self._sigma], name, "NormalEntropy"):
two_pi_e1 = constant_op.constant(
2 * math.pi * math.exp(1), dtype=self.dtype)
# Use broadcasting rules to calculate the full broadcast sigma.
sigma = self._sigma * array_ops.ones_like(self._mu)
return 0.5 * math_ops.log(two_pi_e1 * math_ops.square(sigma))
def sample(self, n, seed=None, name=None):
with ops.name_scope(self.name):
with ops.op_scope([self._mu, self._sigma], name):
two_pi_e1 = constant_op.constant(
2 * math.pi * math.exp(1), dtype=self.dtype)
# Use broadcasting rules to calculate the full broadcast sigma.
sigma = self._sigma * array_ops.ones_like(self._mu)
return 0.5 * math_ops.log(two_pi_e1 * math_ops.square(sigma))
def sample(self, n, seed=None, name="sample"):
"""Sample `n` observations from the Normal Distributions.
Args:
......@@ -269,20 +289,21 @@ class Normal(ContinuousDistribution):
samples: `[n, ...]`, a `Tensor` of `n` samples for each
of the distributions determined by broadcasting the hyperparameters.
"""
with ops.op_scope([self._mu, self._sigma, n], name, "NormalSample"):
broadcast_shape = (self._mu + self._sigma).get_shape()
n = ops.convert_to_tensor(n)
shape = array_ops.concat(
0, [array_ops.pack([n]), array_ops.shape(self.mean)])
sampled = random_ops.random_normal(
shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)
# Provide some hints to shape inference
n_val = tensor_util.constant_value(n)
final_shape = tensor_shape.vector(n_val).concatenate(broadcast_shape)
sampled.set_shape(final_shape)
return sampled * self._sigma + self._mu
with ops.name_scope(self.name):
with ops.op_scope([self._mu, self._sigma, n], name):
broadcast_shape = (self._mu + self._sigma).get_shape()
n = ops.convert_to_tensor(n)
shape = array_ops.concat(
0, [array_ops.pack([n]), array_ops.shape(self.mean())])
sampled = random_ops.random_normal(
shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)
# Provide some hints to shape inference
n_val = tensor_util.constant_value(n)
final_shape = tensor_shape.vector(n_val).concatenate(broadcast_shape)
sampled.set_shape(final_shape)
return sampled * self._sigma + self._mu
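The sampling strategy above, in plain NumPy: draw standard normals of the broadcast shape, then shift and scale (a sketch of the idea, not the TF op):

    import numpy as np

    mu, sigma, n = 3.0, 2.0, 5
    z = np.random.randn(n).astype(np.float32)  # N(0, 1), cf. random_ops.random_normal
    samples = z * sigma + mu                   # distributed N(mu, sigma**2)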
@property
def is_reparameterized(self):
......
......@@ -22,7 +22,7 @@ import math
import numpy as np
from tensorflow.contrib.distributions.python.ops.distribution import ContinuousDistribution # pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops import distribution # pylint: disable=line-too-long
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util # pylint: disable=line-too-long
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
......@@ -35,7 +35,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
class StudentT(ContinuousDistribution):
class StudentT(distribution.ContinuousDistribution):
"""Student's t distribution with degree-of-freedom parameter df.
#### Mathematical details
......@@ -113,8 +113,8 @@ class StudentT(ContinuousDistribution):
contrib_tensor_util.assert_same_float_dtype(
(self._df, self._mu, self._sigma))
self._name = scope
self._batch_shape = self._ones().get_shape()
self._event_shape = tensor_shape.TensorShape([])
self._get_batch_shape = self._ones().get_shape()
self._get_event_shape = tensor_shape.TensorShape([])
@property
def name(self):
......@@ -139,31 +139,45 @@ class StudentT(ContinuousDistribution):
"""Scaling factors of these Student's t distribution(s)."""
return self._sigma
@property
def mean(self, name="mean"):
with ops.name_scope(self.name):
return math_ops.mul(self._mu, self._ones(), name=name)
with ops.op_scope([self._mu], name):
df_gt_1 = self._df > self._ones()
result_if_defined = self._mu * self._ones()
nan = np.nan + self._zeros()
return math_ops.select(df_gt_1, result_if_defined, nan)
@property
def variance(self, name="var"):
def mode(self, name="mode"):
with ops.name_scope(self.name):
with ops.op_scope([], name):
return array_ops.identity(self._mu)
def variance(self, name="variance"):
with ops.name_scope(self.name):
with ops.op_scope([self._df, self._sigma], name):
return math_ops.select(
(self._zeros() + self._df > 2),
self._zeros() + math_ops.square(self._sigma) * self._df /
(self._df - 2),
self._zeros() + np.nan)
def std(self, name="std"):
with ops.name_scope(self.name):
return math_ops.select(
(self._zeros() + self._df > 2),
self._zeros() + math_ops.square(self._sigma) * self._df /
(self._df - 2),
self._zeros() + np.inf,
name=name)
with ops.op_scope([], name):
return math_ops.sqrt(self.variance())
def batch_shape(self, name="batch_shape"):
with ops.name_scope(self.name):
return array_ops.shape(self._ones(), name=name)
with ops.op_scope([], name):
return array_ops.shape(self._ones())
def get_batch_shape(self):
return self._batch_shape
return self._get_batch_shape
def event_shape(self, name="event_shape"):
with ops.name_scope(self.name):
return constant_op.constant(1, name=name)
with ops.op_scope([], name):
return constant_op.constant([], dtype=math_ops.int32)
def get_event_shape(self):
return self._event_shape
......@@ -178,8 +192,8 @@ class StudentT(ContinuousDistribution):
Returns:
log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
"""
with ops.op_scope([self._df, self._mu, self._sigma, x], self.name):
with ops.name_scope(name):
with ops.name_scope(self.name):
with ops.op_scope([self._df, self._mu, self._sigma, x], name):
x = ops.convert_to_tensor(x)
if x.dtype != self.dtype:
raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
......@@ -202,8 +216,8 @@ class StudentT(ContinuousDistribution):
Returns:
pdf: tensor of dtype `dtype`, the pdf values of `x`.
"""
with ops.op_scope([self._df, self._mu, self._sigma, x], self.name):
with ops.name_scope(name):
with ops.name_scope(self.name):
with ops.op_scope([self._df, self._mu, self._sigma, x], name):
x = ops.convert_to_tensor(x)
if x.dtype != self.dtype:
raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
......@@ -224,8 +238,8 @@ class StudentT(ContinuousDistribution):
Returns:
entropy: tensor of dtype `dtype`, the entropy.
"""
with ops.op_scope([self._df, self._sigma], self.name):
with ops.name_scope(name):
with ops.name_scope(self.name):
with ops.op_scope([self._df, self._sigma], name):
u = array_ops.expand_dims(self._df + self._zeros(), -1)
v = array_ops.expand_dims(self._ones(), -1)
beta_arg = array_ops.concat(len(u.get_shape()) - 1, [u, v]) / 2
......@@ -247,8 +261,8 @@ class StudentT(ContinuousDistribution):
samples: a `Tensor` of shape `(n,) + self.batch_shape + self.event_shape`
with values of type `self.dtype`.
"""
with ops.op_scope([self._df, self._mu, self._sigma, n], self.name):
with ops.name_scope(name):
with ops.name_scope(self.name):
with ops.op_scope([self._df, self._mu, self._sigma, n], name):
n = ops.convert_to_tensor(n, name="n")
n_val = tensor_util.constant_value(n)
......
......@@ -18,8 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops.distribution import ContinuousDistribution # pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops import distribution # pylint: disable=line-too-long
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util # pylint: disable=line-too-long
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
......@@ -30,7 +31,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
class Uniform(ContinuousDistribution):
class Uniform(distribution.ContinuousDistribution):
"""Uniform distribution with `a` and `b` parameters.
The PDF of this distribution is constant between [`a`, `b`], and 0 elsewhere.
......@@ -70,11 +71,8 @@ class Uniform(ContinuousDistribution):
"""
with ops.op_scope([a, b], name):
with ops.control_dependencies([check_ops.assert_less(a, b)]):
a = ops.convert_to_tensor(a, name="a")
b = ops.convert_to_tensor(b, name="b")
if a.dtype != b.dtype:
raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
(a.dtype, b.dtype))
a = array_ops.identity(a, name="a")
b = array_ops.identity(b, name="b")
self._a = a
self._b = b
......@@ -94,14 +92,16 @@ class Uniform(ContinuousDistribution):
def batch_shape(self, name="batch_shape"):
with ops.name_scope(self.name):
return array_ops.shape(self._ones(), name=name)
with ops.op_scope([], name):
return array_ops.shape(self._ones())
def get_batch_shape(self):
return self._batch_shape
def event_shape(self, name="event_shape"):
with ops.name_scope(self.name):
return constant_op.constant(1, name=name)
with ops.op_scope([], name):
return constant_op.constant([], dtype=dtypes.int32)
def get_event_shape(self):
return self._event_shape
......@@ -125,8 +125,8 @@ class Uniform(ContinuousDistribution):
pdf: tensor of dtype `dtype`, the pdf values of `x`. If `x` is `nan`, will
return `nan`.
"""
with ops.op_scope([self.a, self.b, x], self.name):
with ops.name_scope(name):
with ops.name_scope(self.name):
with ops.op_scope([self.a, self.b, x], name):
x = ops.convert_to_tensor(x, name="x")
if x.dtype != self.dtype:
raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
......@@ -138,7 +138,7 @@ class Uniform(ContinuousDistribution):
math_ops.logical_or(broadcasted_x < self.a,
broadcasted_x > self.b),
array_ops.zeros_like(broadcasted_x),
(1.0 / self.range) * array_ops.ones_like(broadcasted_x)))
(1.0 / self.range()) * array_ops.ones_like(broadcasted_x)))
def log_pdf(self, x, name="log_pdf"):
return super(Uniform, self).log_pdf(x, name)
......@@ -154,24 +154,23 @@ class Uniform(ContinuousDistribution):
cdf: tensor of dtype `dtype`, the CDFs of `x`. If `x` is `nan`, will
return `nan`.
"""
with ops.op_scope([self.a, self.b, x], self.name):
with ops.name_scope(name):
with ops.name_scope(self.name):
with ops.op_scope([self.a, self.b, x], name):
x = ops.convert_to_tensor(x, name="x")
if x.dtype != self.dtype:
raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
(x.dtype, self.dtype))
broadcasted_x = x * self._ones()
return math_ops.select(broadcasted_x < self.a,
array_ops.zeros_like(broadcasted_x),
math_ops.select(broadcasted_x >= self.b,
array_ops.ones_like(broadcasted_x),
(broadcasted_x - self.a) /
self.range))
broadcasted_x = x * self._ones()
zeros = array_ops.zeros_like(x + self.a + self.b, dtype=self.dtype)
ones = array_ops.ones_like(x + self.a + self.b, dtype=self.dtype)
result_if_not_big = math_ops.select(
x < self.a, zeros, (broadcasted_x - self.a) / self.range())
return math_ops.select(x >= self.b, ones, result_if_not_big)
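The nested selects implement a clamped linear ramp; equivalently, in NumPy:

    import numpy as np

    a, b = 10.0, 100.0
    x = np.array([0.0, 10.0, 55.0, 100.0, 200.0])
    cdf = np.clip((x - a) / (b - a), 0.0, 1.0)   # [0. 0. 0.5 1. 1.]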
def log_cdf(self, x, name="log_cdf"):
with ops.op_scope([self.a, self.b, x], self.name):
with ops.name_scope(name):
with ops.name_scope(self.name):
with ops.op_scope([self.a, self.b, x], name):
x = ops.convert_to_tensor(x, name="x")
return math_ops.log(self.cdf(x))
......@@ -184,9 +183,9 @@ class Uniform(ContinuousDistribution):
Returns:
entropy: tensor of dtype `dtype`, the entropy.
"""
with ops.op_scope([self.a, self.b], self.name):
with ops.name_scope(name):
return math_ops.log(self.range)
with ops.name_scope(self.name):
with ops.op_scope([self.a, self.b, self.range()], name):
return math_ops.log(self.range())
def sample(self, n, seed=None, name="sample"):
"""Sample `n` observations from the Uniform Distributions.
......@@ -200,8 +199,8 @@ class Uniform(ContinuousDistribution):
samples: a `Tensor` of shape `(n,) + self.batch_shape + self.event_shape`
with values of type `self.dtype`.
"""
with ops.op_scope([self.a, self.b, n], self.name):
with ops.name_scope(name):
with ops.name_scope(self.name):
with ops.op_scope([self.a, self.b, n], name):
n = ops.convert_to_tensor(n, name="n")
n_val = tensor_util.constant_value(n)
......@@ -216,20 +215,28 @@ class Uniform(ContinuousDistribution):
samples.set_shape(inferred_shape)
return (array_ops.expand_dims(self.a, 0) + array_ops.expand_dims(
self.range, 0) * samples)
self.range(), 0) * samples)
@property
def mean(self):
return (self.a + self.b) / 2
def mean(self, name="mean"):
with ops.name_scope(self.name):
with ops.op_scope([self._a, self._b], name):
return (self.a + self.b) / 2
@property
def variance(self):
return math_ops.square(self.range) / 12
def variance(self, name="variance"):
with ops.name_scope(self.name):
with ops.op_scope([self.range()], name):
return math_ops.square(self.range()) / 12.
@property
def range(self):
def std(self, name="std"):
with ops.name_scope(self.name):
with ops.op_scope([self.range()], name):
return self.range() / math_ops.sqrt(12.)
def range(self, name="range"):
"""`b - a`."""
return self.b - self.a
with ops.name_scope(self.name):
with ops.op_scope([self.a, self.b], name):
return self.b - self.a
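The closed forms behind the new methods, checked against scipy's parameterization (loc=a, scale=b-a):

    import numpy as np
    from scipy import stats

    a, b = 10.0, 100.0
    u = stats.uniform(loc=a, scale=b - a)
    assert np.isclose(u.mean(), (a + b) / 2)
    assert np.isclose(u.var(), (b - a) ** 2 / 12)
    assert np.isclose(u.std(), (b - a) / np.sqrt(12))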
@property
def is_reparameterized(self):
......
......@@ -22,6 +22,7 @@ cc_library(
name = "sparse_feature_cross_kernel",
srcs = ["sparse_feature_cross_kernel.cc"],
deps = [
"@farmhash_archive//:farmhash",
"@protobuf//:protobuf",
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
......
......@@ -27,112 +27,34 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
namespace {
// Seed is chosen based on third_party/tensorflow/core/lib/hash/hash.h
const int64 kInitialHashSeed = 0xDECAFCAFFE;
// The following functions are a copy of Hash64. They will be replaced by a
// fingerprint function.
// Original code: third_party/tensorflow/core/lib/hash/hash.h
static inline uint64 ByteAs64(char c) { return static_cast<uint64>(c) & 0xff; }
inline uint32 DecodeFixed32(const char* ptr) {
if (port::kLittleEndian) {
// Load the raw bytes
uint32 result;
memcpy(&result, ptr, sizeof(result)); // gcc optimizes this to a plain load
return result;
} else {
return ((static_cast<uint32>(static_cast<unsigned char>(ptr[0]))) |
(static_cast<uint32>(static_cast<unsigned char>(ptr[1])) << 8) |
(static_cast<uint32>(static_cast<unsigned char>(ptr[2])) << 16) |
(static_cast<uint32>(static_cast<unsigned char>(ptr[3])) << 24));
}
}
inline uint64 DecodeFixed64(const char* ptr) {
if (port::kLittleEndian) {
// Load the raw bytes
uint64 result;
memcpy(&result, ptr, sizeof(result)); // gcc optimizes this to a plain load
return result;
} else {
uint64 lo = DecodeFixed32(ptr);
uint64 hi = DecodeFixed32(ptr + 4);
return (hi << 32) | lo;
}
}
uint64 LegacyHashFunction(const char* data, size_t n, uint64 seed) {
const uint64 m = 0xc6a4a7935bd1e995;
const int r = 47;
uint64 h = seed ^ (n * m);
while (n >= 8) {
uint64 k = DecodeFixed64(data);
data += 8;
n -= 8;
k *= m;
k ^= k >> r;
k *= m;
h ^= k;
h *= m;
}
switch (n) {
case 7:
h ^= ByteAs64(data[6]) << 48;
TF_FALLTHROUGH_INTENDED;
case 6:
h ^= ByteAs64(data[5]) << 40;
TF_FALLTHROUGH_INTENDED;
case 5:
h ^= ByteAs64(data[4]) << 32;
TF_FALLTHROUGH_INTENDED;
case 4:
h ^= ByteAs64(data[3]) << 24;
TF_FALLTHROUGH_INTENDED;
case 3:
h ^= ByteAs64(data[2]) << 16;
TF_FALLTHROUGH_INTENDED;
case 2:
h ^= ByteAs64(data[1]) << 8;
TF_FALLTHROUGH_INTENDED;
case 1:
h ^= ByteAs64(data[0]);
h *= m;
}
h ^= h >> r;
h *= m;
h ^= h >> r;
return h;
}
// An interface that represents a column with batches.
template <typename StringType>
template <typename InternalType>
class ColumnInterface {
public:
// Returns the number of features in the specified batch.
virtual int64 FeatureCount(int64 batch) const = 0;
// Returns the nth feature from the specified batch.
virtual StringType Feature(int64 batch, int64 n) const = 0;
// Returns the fingerprint of the nth feature from the specified batch.
InternalType Feature(int64 batch, int64 n) const {
InternalType not_used = InternalType();
return DoFeature(batch, n, not_used);
}
virtual InternalType DoFeature(int64 batch, int64 n,
InternalType not_used) const = 0;
virtual ~ColumnInterface() {}
};
// A column that is backed by a sparse tensor.
template <typename StringType>
class SparseTensorColumn : public ColumnInterface<StringType> {
template <typename InternalType>
class SparseTensorColumn : public ColumnInterface<InternalType> {
public:
SparseTensorColumn(const Tensor& values, std::vector<int64> feature_counts,
std::vector<int64> feature_start_indices)
......@@ -146,11 +68,23 @@ class SparseTensorColumn : public ColumnInterface<StringType> {
return feature_counts_[batch];
}
StringType Feature(int64 batch, int64 n) const override {
int64 DoFeature(int64 batch, int64 n, int64 not_used) const {
const int64 start = feature_start_indices_[batch];
if (DT_STRING == values_.dtype())
return Fingerprint64(values_.vec<string>().data()[start + n]);
return values_.vec<int64>().data()[start + n];
}
string DoFeature(int64 batch, int64 n, string not_used) const {
const int64 start = feature_start_indices_[batch];
if (DT_STRING == values_.dtype())
return values_.vec<string>().data()[start + n];
return StringType(std::to_string(values_.vec<int64>().data()[start + n]));
return std::to_string(values_.vec<int64>().data()[start + n]);
}
StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const {
const int64 start = feature_start_indices_[batch];
return values_.vec<string>().data()[start + n];
}
~SparseTensorColumn() override {}
......@@ -162,16 +96,26 @@ class SparseTensorColumn : public ColumnInterface<StringType> {
};
// A column that is backed by a dense tensor.
template <typename StringType>
class DenseTensorColumn : public ColumnInterface<StringType> {
template <typename InternalType>
class DenseTensorColumn : public ColumnInterface<InternalType> {
public:
explicit DenseTensorColumn(const Tensor& tensor) : tensor_(tensor) {}
int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); }
StringType Feature(int64 batch, int64 n) const override {
int64 DoFeature(int64 batch, int64 n, int64 not_used) const {
if (DT_STRING == tensor_.dtype())
return Fingerprint64(tensor_.matrix<string>()(batch, n));
return tensor_.matrix<int64>()(batch, n);
}
string DoFeature(int64 batch, int64 n, string not_used) const {
if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
return StringType(std::to_string(tensor_.matrix<int64>()(batch, n)));
return std::to_string(tensor_.matrix<int64>()(batch, n));
}
StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const {
return tensor_.matrix<string>()(batch, n);
}
~DenseTensorColumn() override {}
......@@ -209,19 +153,19 @@ class OutputUpdater {
};
// Generates the sparse crosses as concatenation of strings.
template <typename StringType>
template <typename InternalType>
class StringCrosser {
public:
StringCrosser(
const std::vector<std::unique_ptr<ColumnInterface<StringType>>>& columns,
const int64 not_used)
StringCrosser(const std::vector<
std::unique_ptr<ColumnInterface<InternalType>>>& columns,
const int64 not_used)
: columns_(columns) {}
string Generate(const int64 batch_index,
const std::vector<int>& permutation) const {
static const auto k_feature_separator = "_X_";
gtl::InlinedVector<StringType, 6> cross_vec(columns_.size());
gtl::InlinedVector<InternalType, 6> cross_vec(columns_.size());
for (int i = 0; i < permutation.size(); i++) {
cross_vec[i] = columns_[i]->Feature(batch_index, permutation[i]);
}
......@@ -231,15 +175,21 @@ class StringCrosser {
}
private:
const std::vector<std::unique_ptr<ColumnInterface<StringType>>>& columns_;
const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
};
// Seed is chosen based on third_party/tensorflow/core/lib/hash/hash.h
const int64 kInitialHashSeed = 0xDECAFCAFFE;
int64 HashCombine(int64 a, int64 b) {
return a ^ (b + 0x9e3779b97f4a7800 + (a << 10) + (a >> 4));
}
// Generates the sparse crosses as nested hash to avoid string manipulations.
template <typename StringType>
class HashCrosser {
public:
HashCrosser(
const std::vector<std::unique_ptr<ColumnInterface<StringType>>>& columns,
const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
const int64 num_buckets)
: columns_(columns), num_buckets_(num_buckets) {}
......@@ -247,8 +197,8 @@ class HashCrosser {
const std::vector<int>& permutation) const {
uint64 hashed_output = kInitialHashSeed;
for (int i = 0; i < permutation.size(); i++) {
StringType str = columns_[i]->Feature(batch_index, permutation[i]);
hashed_output = LegacyHashFunction(str.data(), str.size(), hashed_output);
int64 hash_i = columns_[i]->Feature(batch_index, permutation[i]);
hashed_output = HashCombine(hashed_output, hash_i);
}
if (num_buckets_ > 0) {
return hashed_output % num_buckets_;
......@@ -259,16 +209,17 @@ class HashCrosser {
}
private:
const std::vector<std::unique_ptr<ColumnInterface<StringType>>>& columns_;
const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
const int64 num_buckets_;
};
// ProductIterator generates cartesian products based on indices.
template <typename StringType>
template <typename InternalType>
class ProductIterator {
public:
explicit ProductIterator(
const std::vector<std::unique_ptr<ColumnInterface<StringType>>>& columns,
const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
columns,
int64 batch_index)
: columns_(columns), batch_index_(batch_index) {
next_permutation_.resize(columns_.size(), 0);
......@@ -306,28 +257,28 @@ class ProductIterator {
private:
bool has_next_;
const std::vector<std::unique_ptr<ColumnInterface<StringType>>>& columns_;
const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
const int64 batch_index_;
std::vector<int> next_permutation_;
};
template <bool HASHED_OUTPUT, typename StringType>
template <bool HASHED_OUTPUT, typename InternalType>
struct CrossTraits;
template <typename StringType>
struct CrossTraits<true, StringType> {
typedef StringCrosser<StringType> Crosser;
template <typename InternalType>
struct CrossTraits<false, InternalType> {
typedef StringCrosser<InternalType> Crosser;
typedef OutputUpdater<string> Updater;
};
template <typename StringType>
struct CrossTraits<false, StringType> {
typedef HashCrosser<StringType> Crosser;
template <>
struct CrossTraits<true, int64> {
typedef HashCrosser Crosser;
typedef OutputUpdater<int64> Updater;
};
} // namespace
template <bool HASHED_OUTPUT, typename StringType>
template <bool HASHED_OUTPUT, typename InternalType>
class SparseFeatureCrossOp : public OpKernel {
public:
explicit SparseFeatureCrossOp(OpKernelConstruction* context)
......@@ -348,11 +299,11 @@ class SparseFeatureCrossOp : public OpKernel {
ValidateInput(context, indices_list_in, values_list_in, shapes_list_in,
dense_list_in);
std::vector<std::unique_ptr<ColumnInterface<StringType>>> columns =
std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns =
GenerateColumnsFromInput(indices_list_in, values_list_in,
shapes_list_in, dense_list_in);
typename CrossTraits<HASHED_OUTPUT, StringType>::Crosser crosser(
typename CrossTraits<HASHED_OUTPUT, InternalType>::Crosser crosser(
columns, num_buckets_);
Tensor* indices_out;
Tensor* values_out;
......@@ -362,11 +313,11 @@ class SparseFeatureCrossOp : public OpKernel {
CreateOutputTensors(columns, batch_size, context, &indices_out, &values_out,
&shape_out, &output_start_indices);
typename CrossTraits<HASHED_OUTPUT, StringType>::Updater updater(
typename CrossTraits<HASHED_OUTPUT, InternalType>::Updater updater(
output_start_indices, indices_out, values_out);
auto do_work = [this, &columns, crosser, updater](int64 begin, int64 end) {
for (int b = begin; b < end; b++) {
ProductIterator<StringType> product_iterator(columns, b);
ProductIterator<InternalType> product_iterator(columns, b);
int64 cross_count = 0;
while (product_iterator.HasNext()) {
const auto permutation = product_iterator.Next();
......@@ -479,12 +430,12 @@ class SparseFeatureCrossOp : public OpKernel {
}
// Generate the columns given the sparse and dense inputs.
std::vector<std::unique_ptr<ColumnInterface<StringType>>>
std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
GenerateColumnsFromInput(const OpInputList& indices_list_in,
const OpInputList& values_list_in,
const OpInputList& shapes_list_in,
const OpInputList& dense_list_in) {
std::vector<std::unique_ptr<ColumnInterface<StringType>>> columns;
std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
const int64 number_of_columns = shapes_list_in.size();
......@@ -497,12 +448,13 @@ class SparseFeatureCrossOp : public OpKernel {
&feature_start_indices);
for (int i = 0; i < values_list_in.size(); ++i) {
columns.emplace_back(new SparseTensorColumn<StringType>(
columns.emplace_back(new SparseTensorColumn<InternalType>(
values_list_in[i], std::move(feature_counts[i]),
std::move(feature_start_indices[i])));
}
for (int i = 0; i < dense_list_in.size(); ++i) {
columns.emplace_back(new DenseTensorColumn<StringType>(dense_list_in[i]));
columns.emplace_back(
new DenseTensorColumn<InternalType>(dense_list_in[i]));
}
return columns;
......@@ -536,7 +488,8 @@ class SparseFeatureCrossOp : public OpKernel {
// It also outputs output_start_indices, which contains the start indices for
// each input in the output SparseTensor.
void CreateOutputTensors(
const std::vector<std::unique_ptr<ColumnInterface<StringType>>>& columns,
const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
columns,
int64 batch_size, OpKernelContext* context, Tensor** indices_out,
Tensor** values_out, Tensor** shape_out,
std::vector<int64>* output_start_indices) {
......@@ -569,7 +522,8 @@ class SparseFeatureCrossOp : public OpKernel {
// Returns number of crosses for a given batch_index
int64 CrossCountByBatchIndex(
const std::vector<std::unique_ptr<ColumnInterface<StringType>>>& columns,
const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
columns,
int batch_index) {
int64 cross_count = 1;
for (int i = 0; i < columns.size(); i++) {
......@@ -589,24 +543,24 @@ REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
.Device(DEVICE_CPU)
.TypeConstraint<string>("out_type")
.TypeConstraint<string>("internal_type"),
SparseFeatureCrossOp<true, StringPiece>);
SparseFeatureCrossOp<false, StringPiece>);
REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
.Device(DEVICE_CPU)
.TypeConstraint<string>("out_type")
.TypeConstraint<int64>("internal_type"),
SparseFeatureCrossOp<true, string>);
SparseFeatureCrossOp<false, string>);
REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
.Device(DEVICE_CPU)
.TypeConstraint<int64>("out_type")
.TypeConstraint<string>("internal_type"),
SparseFeatureCrossOp<false, StringPiece>);
SparseFeatureCrossOp<true, int64>);
REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
.Device(DEVICE_CPU)
.TypeConstraint<int64>("out_type")
.TypeConstraint<int64>("internal_type"),
SparseFeatureCrossOp<false, string>);
SparseFeatureCrossOp<true, int64>);
} // namespace tensorflow
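To make the kernel change concrete, here is a hedged Python sketch (illustrative only, not the kernel's code) of the new hashed-cross path: string features are fingerprinted to int64 up front, each permutation is folded with `HashCombine` starting from `kInitialHashSeed`, and the result is optionally reduced modulo `num_buckets`. The masking emulates 64-bit wraparound, so this mirror is conceptual rather than bit-exact (the C++ code uses signed int64).

~~~python
MASK64 = (1 << 64) - 1

def hash_combine(a, b):
    # Python mirror of the kernel's HashCombine; & MASK64 emulates
    # 64-bit wraparound.
    return (a ^ ((b + 0x9e3779b97f4a7800 +
                  ((a << 10) & MASK64) + (a >> 4)) & MASK64)) & MASK64

def hashed_cross(feature_hashes, num_buckets=0, seed=0xDECAFCAFFE):
    # feature_hashes: per-column int64 values of one permutation,
    # e.g. Fingerprint64("batch1-FC1-F1") for string features.
    out = seed
    for h in feature_hashes:
        out = hash_combine(out, h)
    return out % num_buckets if num_buckets > 0 else out

# The non-hashed path instead joins the string features with "_X_":
# "batch1-FC1-F1_X_batch1-FC2-F1_X_batch1-FC3-F1"
~~~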
......@@ -293,10 +293,8 @@ class SparseCrossOpTest(tf.test.TestCase):
])
],
hashed_output=True)
# Hash64("batch1-FC3-F1",
# Hash64("batch1-FC2-F1",
# Hash64("batch1-FC1-F1"))) = 571927800417497063
expected_out = self._sparse_tensor([[571927800417497063]])
# Check actual hashed output to prevent unintentional hashing changes.
expected_out = self._sparse_tensor([[3735511728867393167]])
with self.test_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
......@@ -316,11 +314,8 @@ class SparseCrossOpTest(tf.test.TestCase):
],
hashed_output=True,
num_buckets=100)
# Hash64("batch1-FC3-F1",
# Hash64("batch1-FC2-F1",
# Hash64("batch1-FC1-F1"))) = 571927800417497063
# 571927800417497063 % 100 = 63
expected_out = self._sparse_tensor([[63]])
# Check actual hashed output to prevent unintentional hashing changes.
expected_out = self._sparse_tensor([[74]])
with self.test_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
......
......@@ -564,18 +564,20 @@ class _RealValuedColumn(_FeatureColumn, collections.namedtuple(
return "{}".format(self)
def insert_transformed_feature(self, columns_to_tensors):
# No transformation is needed for _RealValuedColumn.
columns_to_tensors[self] = columns_to_tensors[self.name]
# No transformation is needed for _RealValuedColumn except reshaping.
input_tensor = columns_to_tensors[self.name]
batch_size = input_tensor.get_shape().as_list()[0]
batch_size = int(batch_size) if batch_size else -1
flattened_shape = [batch_size, self.dimension]
columns_to_tensors[self] = array_ops.reshape(
math_ops.to_float(input_tensor), flattened_shape)
# pylint: disable=unused-argument
def to_dnn_input_layer(self,
input_tensor,
weight_collections=None,
trainable=True):
batch_size = input_tensor.get_shape().as_list()[0]
batch_size = int(batch_size) if batch_size else -1
flattened_shape = [batch_size, self.dimension]
return array_ops.reshape(math_ops.to_float(input_tensor), flattened_shape)
return input_tensor
def to_weighted_sum(self,
input_tensor,
......@@ -740,8 +742,10 @@ class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
def insert_transformed_feature(self, columns_to_tensors):
# Bucketize the source column.
if self.source_column not in columns_to_tensors:
self.source_column.insert_transformed_feature(columns_to_tensors)
columns_to_tensors[self] = bucketization_op.bucketize(
columns_to_tensors[self.source_column.name],
columns_to_tensors[self.source_column],
boundaries=list(self.boundaries))
# pylint: disable=unused-argument
......
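A hedged sketch of the reworked transform chain (column constructors assumed from the `tf.contrib.layers` feature-column API of this revision): the bucketized column now asks its real-valued source column to transform the raw tensor first, then bucketizes the transformed tensor rather than the raw one.

~~~python
import tensorflow as tf

# Hypothetical feature dict keyed by column name.
columns_to_tensors = {'age': tf.constant([[23.0], [67.0]])}

real_col = tf.contrib.layers.real_valued_column('age')
bucket_col = tf.contrib.layers.bucketized_column(real_col,
                                                 boundaries=[18, 35, 65])

# Triggers real_col.insert_transformed_feature first (to_float + reshape),
# then bucketizes columns_to_tensors[real_col].
bucket_col.insert_transformed_feature(columns_to_tensors)
assert real_col in columns_to_tensors and bucket_col in columns_to_tensors
~~~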
......@@ -94,14 +94,17 @@ class DNNClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
weight_column_name=None,
optimizer=None,
activation_fn=nn.relu,
dropout=None):
super(DNNClassifier, self).__init__(n_classes=n_classes,
dropout=None,
config=None):
super(DNNClassifier, self).__init__(model_dir=model_dir,
n_classes=n_classes,
weight_column_name=weight_column_name,
dnn_feature_columns=feature_columns,
dnn_optimizer=optimizer,
dnn_hidden_units=hidden_units,
dnn_activation_fn=activation_fn,
dnn_dropout=dropout)
dnn_dropout=dropout,
config=config)
def _get_train_ops(self, features, targets):
"""See base class."""
......@@ -185,13 +188,16 @@ class DNNRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
weight_column_name=None,
optimizer=None,
activation_fn=nn.relu,
dropout=None):
super(DNNRegressor, self).__init__(weight_column_name=weight_column_name,
dropout=None,
config=None):
super(DNNRegressor, self).__init__(model_dir=model_dir,
weight_column_name=weight_column_name,
dnn_feature_columns=feature_columns,
dnn_optimizer=optimizer,
dnn_hidden_units=hidden_units,
dnn_activation_fn=activation_fn,
dnn_dropout=dropout)
dnn_dropout=dropout,
config=config)
def _get_train_ops(self, features, targets):
"""See base class."""
......
......@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import abc
import inspect
import os
import tempfile
import time
......@@ -74,6 +75,19 @@ def _get_predict_input_fn(x, y, batch_size):
return df.input_builder, df.get_feed_dict_fn()
def _get_arguments(func):
"""Returns list of arguments this function has."""
if hasattr(func, '__code__'):
# Regular function.
return inspect.getargspec(func).args
elif hasattr(func, '__call__'):
# Callable object.
return _get_arguments(func.__call__)
elif hasattr(func, 'func'):
# Partial function.
return _get_arguments(func.func)
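# Illustration (not part of this module) of the forms _get_arguments resolves:
#   def fn(features, targets, mode, params): ...
#       -> ['features', 'targets', 'mode', 'params'] via __code__
#   a callable object
#       -> recurses into its __call__ method
#   functools.partial(fn, ...)
#       -> intended to be unwrapped via its `func` attribute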
class BaseEstimator(sklearn.BaseEstimator):
"""Abstract BaseEstimator class to train and evaluate TensorFlow models.
......@@ -589,17 +603,58 @@ class Estimator(BaseEstimator):
Parameters:
model_fn: Model function, takes features and targets tensors or dicts of
tensors and returns predictions and loss tensors.
E.g. `(features, targets) -> (predictions, loss, train_op)`.
Supports next three signatures for the function:
* `(features, targets) -> (predictions, loss, train_op)`
* `(features, targets, mode) -> (predictions, loss, train_op)`
* `(features, targets, mode, params) ->
(predictions, loss, train_op)`
Where:
* `features` are single `Tensor` or `dict` of `Tensor`s
(depending on data passed to `fit`),
* `targets` are `Tensor` or
`dict` of `Tensor`s (for multi-head model).
* `mode` represents if this is training, evaluation or prediction.
See `ModeKeys` for example keys.
* `params` is a `dict` of hyperparameters. Will receive what is
passed to Estimator in the `params` parameter. This allows
configuring Estimators from hyper-parameter tuning.
model_dir: Directory to save model parameters, graph, etc.
config: Configuration object.
params: `dict` of hyper parameters that will be passed into `model_fn`.
Keys are names of parameters, values are basic python types.
"""
def __init__(self,
model_fn=None,
model_dir=None,
config=None):
config=None,
params=None):
super(Estimator, self).__init__(model_dir=model_dir, config=config)
if model_fn is not None:
# Check number of arguments of the given function matches requirements.
model_fn_args = _get_arguments(model_fn)
if params is not None and 'params' not in model_fn_args:
raise ValueError('Estimator\'s model_fn (%s) has fewer than 4 '
'arguments, but non-None params (%s) were passed.' %
(model_fn, params))
if params is None and 'params' in model_fn_args:
logging.warning('Estimator\'s model_fn (%s) includes the params '
'argument, but params are not passed to Estimator.' %
model_fn)
self._model_fn = model_fn
self.params = params
def _call_model_fn(self, features, targets, mode):
"""Calls model function with support of 2, 3 or 4 arguments."""
model_fn_args = _get_arguments(self._model_fn)
if 'mode' in model_fn_args:
if 'params' in model_fn_args:
return self._model_fn(
features, targets, mode=mode, params=self.params)
else:
return self._model_fn(
features, targets, mode=mode)
return self._model_fn(features, targets)
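# Example of a 4-argument model_fn accepted by the dispatch above (mirrors
# linear_model_params_fn in estimator_test.py):
#
#   def model_fn(features, targets, mode, params):
#     prediction, loss = tf.contrib.learn.models.linear_regression_zero_init(
#         features, targets)
#     train_op = tf.contrib.layers.optimize_loss(
#         loss, tf.contrib.framework.get_global_step(), optimizer='Adagrad',
#         learning_rate=params['learning_rate'])
#     return prediction, loss, train_op
#
#   est = Estimator(model_fn=model_fn, params={'learning_rate': 0.01})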
def _get_train_ops(self, features, targets):
"""Method that builds model graph and returns trainer ops.
......@@ -615,7 +670,7 @@ class Estimator(BaseEstimator):
Returns:
Tuple of train `Operation` and loss `Tensor`.
"""
_, loss, train_op = self._model_fn(features, targets, ModeKeys.TRAIN)
_, loss, train_op = self._call_model_fn(features, targets, ModeKeys.TRAIN)
return train_op, loss
def _get_eval_ops(self, features, targets, metrics):
......@@ -633,7 +688,7 @@ class Estimator(BaseEstimator):
Returns:
metrics: `dict` of `Tensor` objects.
"""
predictions, loss, _ = self._model_fn(features, targets, ModeKeys.EVAL)
predictions, loss, _ = self._call_model_fn(features, targets, ModeKeys.EVAL)
result = {'loss': loss}
metrics = metrics or {}
if isinstance(targets, dict) and len(targets) == 1:
......@@ -663,5 +718,5 @@ class Estimator(BaseEstimator):
"""
targets = tensor_signature.create_placeholders_from_signatures(
self._targets_info)
predictions, _, _ = self._model_fn(features, targets, ModeKeys.INFER)
predictions, _, _ = self._call_model_fn(features, targets, ModeKeys.INFER)
return predictions
......@@ -61,7 +61,19 @@ def boston_eval_fn():
return tf.concat(0, [features, features]), tf.concat(0, [target, target])
def linear_model_fn(features, target, unused_mode):
def linear_model_params_fn(features, target, mode, params):
assert mode in ('train', 'eval', 'infer')
prediction, loss = (
tf.contrib.learn.models.linear_regression_zero_init(features, target)
)
train_op = tf.contrib.layers.optimize_loss(
loss, tf.contrib.framework.get_global_step(), optimizer='Adagrad',
learning_rate=params['learning_rate'])
return prediction, loss, train_op
def linear_model_fn(features, target, mode):
assert mode in ('train', 'eval', 'infer')
prediction, loss = (
tf.contrib.learn.models.linear_regression_zero_init(features, target)
)
......@@ -71,7 +83,7 @@ def linear_model_fn(features, target, unused_mode):
return prediction, loss, train_op
def logistic_model_fn(features, target, unused_mode):
def logistic_model_no_mode_fn(features, target):
target = tf.one_hot(target, 3, 1, 0)
prediction, loss = (
tf.contrib.learn.models.logistic_regression_zero_init(features, target)
......@@ -85,19 +97,26 @@ def logistic_model_fn(features, target, unused_mode):
class CheckCallsMonitor(tf.contrib.learn.monitors.BaseMonitor):
def __init__(self):
self.calls = None
self.begin_calls = None
self.end_calls = None
self.expect_calls = None
def begin(self, max_steps):
self.calls = 0
self.begin_calls = 0
self.end_calls = 0
self.expect_calls = max_steps
def step_begin(self, step):
self.begin_calls += 1
return {}
def step_end(self, step, outputs):
self.calls += 1
self.end_calls += 1
return False
def end(self):
assert self.calls == self.expect_calls
assert (self.end_calls == self.expect_calls and
self.begin_calls == self.expect_calls)
class EstimatorTest(tf.test.TestCase):
......@@ -146,6 +165,12 @@ class EstimatorTest(tf.test.TestCase):
metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
self.assertLess(scores3['MSE'], scores['MSE'])
def testEstimatorParams(self):
boston = tf.contrib.learn.datasets.load_boston()
est = tf.contrib.learn.Estimator(model_fn=linear_model_params_fn,
params={'learning_rate': 0.01})
est.fit(x=boston.data, y=boston.target.astype(np.float32), steps=100)
def testBostonAll(self):
boston = tf.contrib.learn.datasets.load_boston()
est = tf.contrib.learn.Estimator(model_fn=linear_model_fn)
......@@ -160,7 +185,7 @@ class EstimatorTest(tf.test.TestCase):
def testIrisAll(self):
iris = tf.contrib.learn.datasets.load_iris()
est = tf.contrib.learn.Estimator(model_fn=logistic_model_fn)
est = tf.contrib.learn.Estimator(model_fn=logistic_model_no_mode_fn)
est.fit(iris.data, iris.target, steps=100)
scores = est.evaluate(
x=iris.data,
......@@ -177,7 +202,7 @@ class EstimatorTest(tf.test.TestCase):
def testIrisInputFn(self):
iris = tf.contrib.learn.datasets.load_iris()
est = tf.contrib.learn.Estimator(model_fn=logistic_model_fn)
est = tf.contrib.learn.Estimator(model_fn=logistic_model_no_mode_fn)
est.fit(input_fn=iris_input_fn, steps=100)
_ = est.evaluate(input_fn=iris_input_fn, steps=1)
predictions = est.predict(x=iris.data)['class']
......
......@@ -80,13 +80,15 @@ class LinearClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
model_dir=None,
n_classes=2,
weight_column_name=None,
optimizer=None):
optimizer=None,
config=None):
super(LinearClassifier, self).__init__(
model_dir=model_dir,
n_classes=n_classes,
weight_column_name=weight_column_name,
linear_feature_columns=feature_columns,
linear_optimizer=optimizer)
linear_optimizer=optimizer,
config=config)
def _get_train_ops(self, features, targets):
"""See base class."""
......@@ -156,12 +158,14 @@ class LinearRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
model_dir=None,
n_classes=2,
weight_column_name=None,
optimizer=None):
optimizer=None,
config=None):
super(LinearRegressor, self).__init__(
model_dir=model_dir,
weight_column_name=weight_column_name,
linear_feature_columns=feature_columns,
linear_optimizer=optimizer)
linear_optimizer=optimizer,
config=config)
def _get_train_ops(self, features, targets):
"""See base class."""
......
......@@ -105,7 +105,8 @@ def _restore_from_checkpoint(session, graph, checkpoint_path, saver=None):
def _run_with_monitors(session, step, tensors, feed_dict, monitors):
"""Runs session for given tensors with monitor callbacks."""
for monitor in monitors:
tensors = monitor.step_begin(step, tensors)
tensors += monitor.step_begin(step)
tensors = list(set(tensors))
outputs = session.run(tensors, feed_dict=feed_dict)
outputs = dict(zip(
......
......@@ -50,7 +50,7 @@ class BaseMonitor(object):
def epoch_end(self, epoch):
pass
def step_begin(self, step, tensors): # pylint: disable=unused-argument
def step_begin(self, step): # pylint: disable=unused-argument
"""Callback before training step begins.
Use this callback to:
......@@ -58,12 +58,11 @@ class BaseMonitor(object):
Args:
step: int, global step of the model.
tensors: list of `Tensors` that going to be passed to session.run.
Returns:
Dict of `Tensors` that going to be ran.
List of `Tensors` that are going to be run.
"""
return tensors
return []
def step_end(self, step, output): # pylint: disable=unused-argument
"""Callback after training step finished.
......@@ -99,26 +98,30 @@ class EveryN(BaseMonitor):
self._every_n_steps = every_n_steps
self._first_n_steps = first_n_steps
self._max_steps = None
self._last_step = 0
def begin(self, max_steps=None):
self._max_steps = max_steps
def every_n_step_begin(self, step, tensors): # pylint: disable=unused-argument
return tensors
def every_n_step_begin(self, step): # pylint: disable=unused-argument
return []
def every_n_step_end(self, step, outputs): # pylint: disable=unused-argument
return False
def step_begin(self, step, tensors):
if (step <= self._first_n_steps or step % self._every_n_steps == 0 or
def step_begin(self, step):
if (step <= self._first_n_steps or
step >= (self._every_n_steps + self._last_step) or
step == self._max_steps):
tensors = self.every_n_step_begin(step, tensors)
return tensors
return self.every_n_step_begin(step)
return []
def step_end(self, step, output):
to_stop = False
if (step <= self._first_n_steps or step % self._every_n_steps == 0 or
if (step <= self._first_n_steps or
step >= (self._every_n_steps + self._last_step) or
step == self._max_steps):
self._last_step = step
to_stop = self.every_n_step_end(step, output)
return to_stop
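A hedged trace of the new schedule: rather than `step % every_n_steps == 0`, the monitor now fires whenever at least `every_n_steps` have elapsed since the step that last fired (plus the first `first_n_steps` and the final step):

~~~python
every_n_steps, first_n_steps, max_steps, last_step = 100, 1, 300, 0
fired = []
for step in (1, 50, 101, 150, 201, 300):
    # Mirrors the condition in step_begin/step_end above.
    if (step <= first_n_steps or
            step >= every_n_steps + last_step or
            step == max_steps):
        last_step = step
        fired.append(step)
print(fired)  # [1, 101, 201, 300]
~~~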
......@@ -135,8 +138,8 @@ class PrintTensor(EveryN):
super(PrintTensor, self).__init__(every_n, first_n)
self._tensor_names = tensor_names
def every_n_step_begin(self, unused_step, tensors):
return tensors + self._tensor_names
def every_n_step_begin(self, unused_step):
return self._tensor_names
def every_n_step_end(self, step, outputs):
stats = []
......@@ -162,8 +165,8 @@ class SummarySaver(EveryN):
super(SummarySaver, self).set_estimator(estimator)
self._summary_writer = summary_io.SummaryWriter(self._estimator.model_dir)
def every_n_step_begin(self, unused_step, tensors):
return tensors + [self._summary_op]
def every_n_step_begin(self, unused_step):
return [self._summary_op]
def every_n_step_end(self, step, outputs):
summary_strs = outputs[self._summary_op.name]
......@@ -225,8 +228,8 @@ class CaptureVariable(EveryN):
self.var_name = var_name
self.var_values = []
def every_n_step_begin(self, unused_step, tensors):
return tensors + [self.var_name]
def every_n_step_begin(self, unused_step):
return [self.var_name]
def every_n_step_end(self, step, outputs):
self.var_values.append(outputs[self.var_name])
......
......@@ -65,6 +65,7 @@ class NonLinearTest(tf.test.TestCase):
self.assertEqual(len(biases), 5)
def testDNNDropout0(self):
random.seed(42)
# Dropout prob == 0.
iris = tf.contrib.learn.datasets.load_iris()
classifier = tf.contrib.learn.TensorFlowDNNClassifier(
......@@ -74,6 +75,7 @@ class NonLinearTest(tf.test.TestCase):
self.assertGreater(score, 0.9, "Failed with score = {0}".format(score))
def testDNNDropout0_1(self):
random.seed(42)
# Dropping only a little.
tf.set_random_seed(42)
iris = tf.contrib.learn.datasets.load_iris()
......@@ -85,6 +87,7 @@ class NonLinearTest(tf.test.TestCase):
self.assertGreater(score, 0.9, "Failed with score = {0}".format(score))
def testDNNDropout0_9(self):
random.seed(42)
# Dropping out most of it.
iris = tf.contrib.learn.datasets.load_iris()
classifier = tf.contrib.learn.TensorFlowDNNClassifier(
......@@ -93,7 +96,7 @@ class NonLinearTest(tf.test.TestCase):
score = accuracy_score(iris.target, classifier.predict(iris.data))
self.assertGreater(score, 0.3, "Failed with score = {0}".format(score))
# If the quality is higher - dropout is not working.
self.assertLess(score, 0.5, "Failed with score = {0}".format(score))
self.assertLess(score, 0.6, "Failed with score = {0}".format(score))
def testRNN(self):
random.seed(42)
......
# Description: Tensorflow Serving session bundle.
package(
default_visibility = ["//visibility:public"],
features = [
"-layering_check",
],
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
"g3doc/sitemap.md",
],
),
)
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
py_library(
name = "exporter",
srcs = ["exporter.py"],
srcs_version = "PY2AND3",
deps = [
":gc",
":manifest_proto_py",
"//tensorflow:tensorflow_py",
],
)
py_test(
name = "exporter_test",
size = "small",
srcs = [
"exporter_test.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:private"],
deps = [
":exporter",
":gc",
":manifest_proto_py",
"//tensorflow:tensorflow_py",
],
)
py_library(
name = "gc",
srcs = ["gc.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
],
)
py_test(
name = "gc_test",
srcs = [
"gc_test.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:private"],
deps = [
":gc",
"//tensorflow:tensorflow_py",
],
)
cc_library(
name = "session_bundle",
srcs = ["session_bundle.cc"],
hdrs = ["session_bundle.h"],
deps = [
":manifest_proto_cc",
":signature",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow_opensource",
],
)
cc_test(
name = "session_bundle_test",
size = "small",
srcs = ["session_bundle_test.cc"],
data = [
"//tensorflow/contrib/session_bundle/example:half_plus_two",
],
# Link in all registered kernels.
linkstatic = 1,
visibility = ["//visibility:private"],
deps = [
":session_bundle",
":test_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "signature",
srcs = ["signature.cc"],
hdrs = ["signature.h"],
deps = [
":manifest_proto_cc",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow_opensource",
],
)
cc_test(
name = "signature_test",
size = "small",
srcs = ["signature_test.cc"],
visibility = ["//visibility:private"],
deps = [
":manifest_proto_cc",
":signature",
":test_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow_opensource",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "test_util",
testonly = 1,
srcs = ["test_util.cc"],
hdrs = ["test_util.h"],
visibility = ["//visibility:private"],
deps = [
"//tensorflow/core",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
)
tf_proto_library(
name = "manifest_proto",
srcs = ["manifest.proto"],
cc_api_version = 2,
py_api_version = 2,
visibility = ["//visibility:public"],
)
# TensorFlow Inference Model Format
[TOC]
## Overview
This document describes the data formats and layouts for exporting [TensorFlow]
(https://www.tensorflow.org/) models for inference.
These exports have the following properties:
* Recoverable
* given an export the graph can easily be initialized and run
* Hermetic
* an export directory is self-contained to facilitate distribution
## Directory Structure
~~~
# Directory overview
00000000/
assets/
export.meta
export-?????-of-?????
~~~
* `00000000` Export version
* Format `%08d`
* `assets` Asset file directory
* Holds auxiliary files for the graph (e.g., vocabularies)
* `export.meta` MetaGraph Definition
* Binary [`tensorflow::MetaGraphDef`]
(https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/protobuf/meta_graph.proto)
* `export-?????-of-?????`
* Graph Variables
* Outputs from Python [`Saver`]
(https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/training/saver.py)
with `sharded=True`.
## Python exporting code
The [`Exporter`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/session_bundle/exporter.py)
class can be used to export a model in the above format from a TensorFlow Python
binary.
## C++ initialization code
The [`LoadSessionBundleFromPath`]
(https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/session_bundle/session_bundle.h)
function can be used to create a `tensorflow::Session` and initialize it from an
export. This function takes options and the path to the export and returns a
bundle of export data including a `tensorflow::Session` which can be run.
## Signatures
Graphs used for inference tasks typically have a set of inputs and outputs used
at inference time. We call this a signature.
### Standard Signatures (standard usage)
Graphs used for standard inference tasks have a standard set of inputs and
outputs. For example, a graph used for a regression task has an input tensor for
the data and an output tensor for the regression values. The signature mechanism
makes it easy to identify the relevant input and output tensors for common graph
applications.
The [`Manifest`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/session_bundle/manifest.proto)
contains a `Signature` message which contains the task specific inputs and
outputs.
~~~
// A Signature specifies the inputs and outputs of commonly used graphs.
message Signature {
oneof type {
RegressionSignature regression_signature = 1;
ClassificationSignature classification_signature = 2;
GenericSignature generic_signature = 3;
}
};
~~~
A standard signature can be set at export time using the `Exporter` API:
~~~python
# Run an export.
signature = exporter.classification_signature(input_tensor=input,
classes_tensor=output)
export = exporter.Exporter(saver)
export.init(sess.graph.as_graph_def(),
default_graph_signature=signature)
export.export(export_path,
global_step_tensor,
sess)
~~~
These can be recovered at serving time using utilities in [`signature.h`]
(https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/session_bundle/signature.h)
~~~c++
// Get the classification signature.
ClassificationSignature signature;
TF_CHECK_OK(GetClassificationSignature(bundle->meta_graph_def, &signature));
// Run the graph.
Tensor input_tensor = GetInputTensor();
Tensor classes_tensor;
Tensor scores_tensor;
TF_CHECK_OK(RunClassification(signature, input_tensor, session,
&classes_tensor, &scores_tensor));
~~~
### Generic Signatures (custom or advanced usage)
Generic Signatures enable fully custom usage of the `tensorflow::Session` API.
They are recommended when the standard Signatures do not satisfy a
particular use-case. A typical example is a model that takes a single input
and generates multiple outputs performing different inferences.
~~~
// GenericSignature specifies a map from logical name to Tensor name.
// Typical application of GenericSignature is to use a single GenericSignature
// that includes all of the Tensor nodes and target names that may be useful at
// serving, analysis or debugging time. The recommended name for this signature
// is "generic_bindings".
message GenericSignature {
map<string, TensorBinding> map = 1;
};
~~~
Generic Signatures can be used to complement a standard signature, for example
to support debugging. Here is an example usage including both the standard
regression signature and a generic signature.
~~~python
named_tensor_bindings = {"logical_input_A": v0,
"logical_input_B": v1}
signatures = {
"regression": exporter.regression_signature(input_tensor=v0,
output_tensor=v1),
"generic": exporter.generic_signature(named_tensor_bindings)}
export = exporter.Exporter(saver)
export.init(sess.graph.as_graph_def(),
named_graph_signature=signatures)
export.export(export_path,
global_step_tensor,
sess)
~~~
Generic Signature does not differentiate between input and output tensors. It
provides full flexibility to specify the input & output tensors you need.
The benefit is preserving a mapping between names that you specify at export
time (we call these the logical names), and the actual graph node names that may
be less stable and/or auto-generated by TensorFlow.
In [`signature.h`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/session_bundle/signature.h),
note that the generic signature methods BindGenericInputs and BindGenericNames
do simple string-to-string mapping as a convenience. These methods map
from the names used at training time to actual names in the graph. Use the bound
results from those methods, e.g. `vector<pair<string, Tensor>>` and
`vector<string>` respectively, as inputs to `tensorflow::Session->Run()`. For
`Session->Run()`, map these into the first two parameters, `inputs` and
`output_tensor_names` respectively. The next param, `target_node_names`, is
typically null at inference time. The last param, `outputs`, holds the results
in the same order as your `output_tensor_names`.
## Initialization
Some graphs may require custom initialization after the variables have been
restored. Such initialization, done through an arbitrary Op, can be added using
the `Exporter` API. If set, `LoadSessionBundleFromPath` will automatically run
the Op when restoring a `Session` following the loading of variables.
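For illustration, a minimal sketch of wiring a custom init Op through the `Exporter` (the grouped assign Ops here are hypothetical):

~~~python
# Group custom initialization into a single Op and register it at export time;
# LoadSessionBundleFromPath runs it after the variables are restored.
init_op = tf.group(assign_asset_path, assign_vocab_table)  # hypothetical assigns
export = exporter.Exporter(saver)
export.init(sess.graph.as_graph_def(),
            init_op=init_op,
            default_graph_signature=signature)
export.export(export_path, global_step_tensor, sess)
~~~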
## Assets
In many cases we have Ops which depend on external files for initialization
(such as vocabularies). These "assets" are not stored in the graph and are
needed for both training and inference.
In order to create hermetic exports these asset files need to be 1) copied to
each export directory and 2) read when recovering a session from an export base
directory.
Copying assets to the export dir is handled with a callback mechanism.
The callback function receives two parameters: 1) a dictionary mapping source
files to desired basenames and 2) the export directory. The default callback uses
`gfile.Copy` to perform the copy.
The tensors that contain the filepaths to be copied (and replaced at
inference time) are specified by passing the collection of asset filepath
tensors, which is usually extracted from the graph by `tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS)`.
~~~python
# Run an export.
export = exporter.Exporter(save)
export.init(
sess.graph.as_graph_def(),
asset_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS))
export.export(export_path, global_step_tensor, sess)
~~~
Users can use their own callbacks as shown in the following example, with the
requirement to keep the basename of the original files:
~~~python
def my_custom_copy_callback(files_to_copy, export_dir_path):
# Copy all source files (keys) in files_to_copy to export_dir_path
# using the corresponding basename (value).
...
# Run an export.
export = exporter.Exporter(save)
export.init(
sess.graph.as_graph_def(),
asset_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS),
asset_callback=my_custom_copy_callback)
export.export(export_path, global_step_tensor, sess)
~~~
`AssetFile` binds the name of a tensor in the graph to the name of a file
within the assets directory. `LoadSessionBundleFromPath` will handle the base
path and asset directory swap/concatenation such that the tensor is set with
the fully qualified filename upon return.
# Notes on exporter usage
The typical workflow of model exporting is:
1. Build model graph G
2. Train variables or load trained variables from checkpoint in session S
3. [Optional] build inference graph I
4. Export G
The Exporter should be used as follows:
1. The Saver used in Exporter(saver) should be created under the context of G
2. Exporter.init() should be called under the context of G
3. Exporter.export() should be called using session S
4. If I is provided to Exporter.init(), a Saver identical to the one under G
should be created under I, so that exactly the same Save/Restore ops are
created in both G and I
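A condensed sketch of that workflow (graph, session and path names hypothetical):

~~~python
g = tf.Graph()
with g.as_default():
    ...  # 1. Build model graph G.
    saver = tf.train.Saver(sharded=True)  # Saver created under G's context.
    export = exporter.Exporter(saver)
    export.init(g.as_graph_def())         # 2. init() called under G.
with tf.Session(graph=g) as sess:
    ...  # Train variables or restore them from a checkpoint in session S.
    export.export(export_path, global_step_tensor, sess)  # 3. export() with S.
~~~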
# Description: Tensorflow Serving session_bundle example.
package(
default_visibility = ["//tensorflow/contrib/session_bundle:__subpackages__"],
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
# vardef("PYTHON_BIN_PATH", "/usr/bin/python")
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
"g3doc/sitemap.md",
],
),
visibility = ["//visibility:public"],
)
py_binary(
name = "export_half_plus_two",
srcs = [
"export_half_plus_two.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/session_bundle:exporter",
],
)
genrule(
name = "half_plus_two",
outs = [
"half_plus_two/00000123/export.meta",
"half_plus_two/00000123/export-00000-of-00001",
],
cmd =
"rm -rf /tmp/half_plus_two; " +
"$(PYTHON_BIN_PATH) $(locations :export_half_plus_two); " +
"cp -r /tmp/half_plus_two/* $(@D)/half_plus_two",
tools = [
":export_half_plus_two",
],
visibility = ["//visibility:public"],
)
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Exports a toy linear regression inference graph.
Exports a TensorFlow graph to /tmp/half_plus_two/ based on the Exporter
format, go/tf-exporter.
This graph calculates:
y = a*x + b
where a and b are variables with a=0.5 and b=2.
Output from this program is typically used to exercise Session
loading and execution code.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.session_bundle import exporter
def Export():
export_path = "/tmp/half_plus_two"
with tf.Session() as sess:
# Make model parameters a&b variables instead of constants to
# exercise the variable reloading mechanisms.
a = tf.Variable(0.5, name="a")
b = tf.Variable(2.0, name="b")
# Calculate, y = a*x + b
# here we use a placeholder 'x' which is fed at inference time.
x = tf.placeholder(tf.float32, name="x")
y = tf.add(tf.mul(a, x), b, name="y")
# Setup a standard Saver for our variables.
save = tf.train.Saver({"a": a, "b": b}, sharded=True)
# asset_path contains the base directory of assets used in training (e.g.
# vocabulary files).
original_asset_path = tf.constant("/tmp/original/export/assets")
# Ops reading asset files should reference the asset_path tensor
# which stores the original asset path at training time and the
# overridden assets directory at restore time.
asset_path = tf.Variable(original_asset_path,
name="asset_path",
trainable=False,
collections=[])
assign_asset_path = asset_path.assign(original_asset_path)
# Use a fixed global step number.
global_step_tensor = tf.Variable(123, name="global_step")
# Create a RegressionSignature for our input and output.
signature = exporter.regression_signature(input_tensor=x, output_tensor=y)
# Create two filename assets and corresponding tensors.
# TODO(b/26254158) Consider adding validation of file existence as well as
# hashes (e.g. sha1) for consistency.
original_filename1 = tf.constant("hello1.txt")
tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, original_filename1)
filename1 = tf.Variable(original_filename1,
name="filename1",
trainable=False,
collections=[])
assign_filename1 = filename1.assign(original_filename1)
original_filename2 = tf.constant("hello2.txt")
tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, original_filename2)
filename2 = tf.Variable(original_filename2,
name="filename2",
trainable=False,
collections=[])
assign_filename2 = filename2.assign(original_filename2)
# Init op contains a group of all variables that we assign.
init_op = tf.group(assign_asset_path, assign_filename1, assign_filename2)
# CopyAssets is used as a callback during export to copy files to the
# given export directory.
def CopyAssets(filepaths, export_path):
print("copying asset files to: %s" % export_path)
for filepath in filepaths:
print("copying asset file: %s" % filepath)
# Run an export.
tf.initialize_all_variables().run()
export = exporter.Exporter(save)
export.init(
sess.graph.as_graph_def(),
init_op=init_op,
default_graph_signature=signature,
assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS),
assets_callback=CopyAssets)
export.export(export_path, global_step_tensor, sess)
def main(_):
Export()
if __name__ == "__main__":
tf.app.run()
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Export a TensorFlow model.
See: go/tf-exporter
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import re
import six
import tensorflow as tf
from google.protobuf.any_pb2 import Any
from tensorflow.contrib.session_bundle import gc
from tensorflow.contrib.session_bundle import manifest_pb2
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.training import training_util
from tensorflow.python.util import compat
# See: go/tf-exporter for these constants and directory structure.
VERSION_FORMAT_SPECIFIER = "%08d"
ASSETS_DIRECTORY = "assets"
EXPORT_BASE_NAME = "export"
EXPORT_SUFFIX_NAME = "meta"
META_GRAPH_DEF_FILENAME = EXPORT_BASE_NAME + "." + EXPORT_SUFFIX_NAME
VARIABLES_FILENAME = EXPORT_BASE_NAME
VARIABLES_FILENAME_PATTERN = VARIABLES_FILENAME + "-?????-of-?????"
INIT_OP_KEY = "serving_init_op"
SIGNATURES_KEY = "serving_signatures"
ASSETS_KEY = "serving_assets"
GRAPH_KEY = "serving_graph"
def gfile_copy_callback(files_to_copy, export_dir_path):
"""Callback to copy files using `gfile.Copy` to an export directory.
This method is used as the default `assets_callback` in `Exporter.init` to
copy assets from the `assets_collection`. It can also be invoked directly to
copy additional supplementary files into the export directory (in which case
it is not a callback).
Args:
files_to_copy: A dictionary that maps original file paths to desired
basenames in the export directory.
export_dir_path: Directory to copy the files to.
"""
tf.logging.info("Write assest into: %s using gfile_copy.", export_dir_path)
gfile.MakeDirs(export_dir_path)
for source_filepath, basename in files_to_copy.items():
new_path = os.path.join(
compat.as_bytes(export_dir_path), compat.as_bytes(basename))
tf.logging.info("Copying asset %s to path %s.", source_filepath, new_path)
if gfile.Exists(new_path):
# Guard against being restarted while copying assets, and the file
# existing and being in an unknown state.
# TODO(b/28676216): Do some file checks before deleting.
tf.logging.info("Removing file %s.", new_path)
gfile.Remove(new_path)
tf.gfile.Copy(source_filepath, new_path)
def regression_signature(input_tensor, output_tensor):
"""Creates a regression signature.
Args:
input_tensor: Tensor specifying the input to a graph.
output_tensor: Tensor specifying the output of a graph.
Returns:
A Signature message.
"""
signature = manifest_pb2.Signature()
signature.regression_signature.input.tensor_name = input_tensor.name
signature.regression_signature.output.tensor_name = output_tensor.name
return signature
def classification_signature(input_tensor,
classes_tensor=None,
scores_tensor=None):
"""Creates a classification signature.
Args:
input_tensor: Tensor specifying the input to a graph.
classes_tensor: Tensor specifying the output classes of a graph.
scores_tensor: Tensor specifying the scores of the output classes.
Returns:
A Signature message.
"""
signature = manifest_pb2.Signature()
signature.classification_signature.input.tensor_name = input_tensor.name
if classes_tensor is not None:
signature.classification_signature.classes.tensor_name = classes_tensor.name
if scores_tensor is not None:
signature.classification_signature.scores.tensor_name = scores_tensor.name
return signature
def generic_signature(name_tensor_map):
"""Creates a generic signature of name to Tensor name.
Args:
name_tensor_map: Map from logical name to Tensor.
Returns:
A Signature message.
"""
signature = manifest_pb2.Signature()
for name, tensor in six.iteritems(name_tensor_map):
signature.generic_signature.map[name].tensor_name = tensor.name
return signature
class Exporter(object):
"""Exporter helps package a TensorFlow model for serving.
Args:
saver: Saver object.
"""
def __init__(self, saver):
self._saver = saver
self._has_init = False
self._assets_to_copy = {}
def init(self,
graph_def=None,
init_op=None,
clear_devices=False,
default_graph_signature=None,
named_graph_signatures=None,
assets_collection=None,
assets_callback=gfile_copy_callback):
"""Initialization.
Args:
graph_def: A GraphDef message of the graph to be used in inference.
GraphDef of default graph is used when None.
init_op: Op to be used in initialization.
clear_devices: If device info of the graph should be cleared upon export.
default_graph_signature: Default signature of the graph.
named_graph_signatures: Map of named input/output signatures of the graph.
assets_collection: A collection of constant asset filepath tensors. If set
the assets will be exported into the asset directory.
assets_callback: callback with two arguments, called during export with the
list of files to copy and the asset path.
Raises:
RuntimeError: if init is called more than once.
TypeError: if init_op is not an Operation or None.
ValueError: if asset file path tensors are not non-empty constant string
scalar tensors.
"""
# Avoid Dangerous default value []
if named_graph_signatures is None:
named_graph_signatures = {}
assets = []
if assets_collection:
for asset_tensor in assets_collection:
asset_filepath = self._file_path_value(asset_tensor)
if not asset_filepath:
raise ValueError("invalid asset filepath tensor %s" % asset_tensor)
basename = os.path.basename(asset_filepath)
assets.append((basename, asset_tensor))
self._assets_to_copy[asset_filepath] = basename
if self._has_init:
raise RuntimeError("init should be called only once")
self._has_init = True
if graph_def or clear_devices:
copy = tf.GraphDef()
if graph_def:
copy.CopyFrom(graph_def)
else:
copy.CopyFrom(tf.get_default_graph().as_graph_def())
if clear_devices:
for node in copy.node:
node.device = ""
graph_any_buf = Any()
graph_any_buf.Pack(copy)
tf.add_to_collection(GRAPH_KEY, graph_any_buf)
if init_op:
if not isinstance(init_op, ops.Operation):
raise TypeError("init_op needs to be an Operation: %s" % init_op)
tf.add_to_collection(INIT_OP_KEY, init_op)
signatures_proto = manifest_pb2.Signatures()
if default_graph_signature:
signatures_proto.default_signature.CopyFrom(default_graph_signature)
for signature_name, signature in six.iteritems(named_graph_signatures):
signatures_proto.named_signatures[signature_name].CopyFrom(signature)
signatures_any_buf = Any()
signatures_any_buf.Pack(signatures_proto)
tf.add_to_collection(SIGNATURES_KEY, signatures_any_buf)
for filename, tensor in assets:
asset = manifest_pb2.AssetFile()
asset.filename = filename
asset.tensor_binding.tensor_name = tensor.name
asset_any_buf = Any()
asset_any_buf.Pack(asset)
tf.add_to_collection(ASSETS_KEY, asset_any_buf)
self._assets_callback = assets_callback
def export(self,
export_dir_base,
global_step_tensor,
sess=None,
exports_to_keep=None):
"""Exports the model.
Args:
export_dir_base: A string path to the base export dir.
global_step_tensor: A Tensor or tensor name providing the
global step counter to append to the export directory path and set
in the manifest version.
sess: A Session to use to save the parameters.
exports_to_keep: a gc.Path filter function used to determine the set of
exports to keep. If set to None, all versions will be kept.
Returns:
The string path to the exported directory.
Raises:
RuntimeError: if init is not called.
RuntimeError: if the export would overwrite an existing directory.
"""
if not self._has_init:
raise RuntimeError("init must be called first")
global_step = training_util.global_step(sess, global_step_tensor)
export_dir = os.path.join(
compat.as_bytes(export_dir_base),
compat.as_bytes(VERSION_FORMAT_SPECIFIER % global_step))
# Prevent overwriting on existing exports which could lead to bad/corrupt
# storage and loading of models. This is an important check that must be
# done before any output files or directories are created.
if gfile.Exists(export_dir):
raise RuntimeError("Overwriting exports can cause corruption and are "
"not allowed. Duplicate export dir: %s" % export_dir)
# Output to a temporary directory which is atomically renamed to the final
# directory when complete.
tmp_export_dir = compat.as_text(export_dir) + "-tmp"
gfile.MakeDirs(tmp_export_dir)
self._saver.save(sess,
os.path.join(
compat.as_text(tmp_export_dir),
compat.as_text(EXPORT_BASE_NAME)),
meta_graph_suffix=EXPORT_SUFFIX_NAME)
# Run the asset callback.
if self._assets_callback and self._assets_to_copy:
assets_dir = os.path.join(
compat.as_bytes(tmp_export_dir), compat.as_bytes(ASSETS_DIRECTORY))
gfile.MakeDirs(assets_dir)
self._assets_callback(self._assets_to_copy, assets_dir)
# TODO(b/27794910): Delete *checkpoint* file before rename.
gfile.Rename(tmp_export_dir, export_dir)
if exports_to_keep:
# create a simple parser that pulls the export_version from the directory.
def parser(path):
match = re.match("^" + export_dir_base + "/(\\d{8})$", path.path)
if not match:
return None
return path._replace(export_version=int(match.group(1)))
paths_to_delete = gc.negation(exports_to_keep)
for p in paths_to_delete(gc.get_paths(export_dir_base, parser=parser)):
gfile.DeleteRecursively(p.path)
return export_dir
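# Example: keep only the two most recent exports (see exporter_test.py):
#   export.export(export_path, global_step_tensor, sess,
#                 exports_to_keep=gc.largest_export_versions(2))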
def _file_path_value(self, path_tensor):
"""Returns the filepath value stored in constant `path_tensor`."""
if not isinstance(path_tensor, tf.Tensor):
raise TypeError("tensor is not a Tensor")
if path_tensor.op.type != "Const":
raise TypeError("Only constants tensor are supported")
if path_tensor.dtype != tf.string:
raise TypeError("File paths should be string")
str_value = path_tensor.op.get_attr("value").string_val
if len(str_value) != 1:
raise TypeError("Only scalar tensors are supported")
return str_value[0]
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for exporter.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import tensorflow as tf
from tensorflow.contrib.session_bundle import exporter
from tensorflow.contrib.session_bundle import gc
from tensorflow.contrib.session_bundle import manifest_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile
FLAGS = flags.FLAGS
GLOBAL_STEP = 222
def tearDownModule():
gfile.DeleteRecursively(tf.test.get_temp_dir())
class SaveRestoreShardedTest(tf.test.TestCase):
def doBasicsOneExportPath(self,
export_path,
clear_devices=False,
global_step=GLOBAL_STEP,
sharded=True):
# Build a graph with 2 parameter nodes on different devices.
tf.reset_default_graph()
with tf.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
# v2 is an unsaved variable derived from v0 and v1. It is used to
# exercise the ability to run an init op when restoring a graph.
with sess.graph.device("/cpu:0"):
v0 = tf.Variable(10, name="v0")
with sess.graph.device("/cpu:1"):
v1 = tf.Variable(20, name="v1")
v2 = tf.Variable(1, name="v2", trainable=False, collections=[])
assign_v2 = tf.assign(v2, tf.add(v0, v1))
init_op = tf.group(assign_v2, name="init_op")
tf.add_to_collection("v", v0)
tf.add_to_collection("v", v1)
tf.add_to_collection("v", v2)
global_step_tensor = tf.Variable(global_step, name="global_step")
named_tensor_bindings = {"logical_input_A": v0, "logical_input_B": v1}
signatures = {
"foo": exporter.regression_signature(input_tensor=v0,
output_tensor=v1),
"generic": exporter.generic_signature(named_tensor_bindings)
}
asset_filepath_orig = os.path.join(tf.test.get_temp_dir(), "hello42.txt")
asset_file = tf.constant(asset_filepath_orig, name="filename42")
tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, asset_file)
with gfile.FastGFile(asset_filepath_orig, "w") as f:
f.write("your data here")
assets_collection = tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS)
ignored_asset = os.path.join(tf.test.get_temp_dir(), "ignored.txt")
with gfile.FastGFile(ignored_asset, "w") as f:
f.write("additional data here")
tf.initialize_all_variables().run()
# Run an export.
save = tf.train.Saver({"v0": v0,
"v1": v1},
restore_sequentially=True,
sharded=sharded)
export = exporter.Exporter(save)
export.init(sess.graph.as_graph_def(),
init_op=init_op,
clear_devices=clear_devices,
default_graph_signature=exporter.classification_signature(
input_tensor=v0),
named_graph_signatures=signatures,
assets_collection=assets_collection)
export.export(export_path,
global_step_tensor,
sess,
exports_to_keep=gc.largest_export_versions(2))
# Restore graph.
compare_def = tf.get_default_graph().as_graph_def()
tf.reset_default_graph()
with tf.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
save = tf.train.import_meta_graph(
os.path.join(export_path, exporter.VERSION_FORMAT_SPECIFIER %
global_step, exporter.META_GRAPH_DEF_FILENAME))
self.assertIsNotNone(save)
meta_graph_def = save.export_meta_graph()
collection_def = meta_graph_def.collection_def
# Validate custom graph_def.
graph_def_any = collection_def[exporter.GRAPH_KEY].any_list.value
self.assertEquals(len(graph_def_any), 1)
graph_def = tf.GraphDef()
graph_def_any[0].Unpack(graph_def)
if clear_devices:
for node in compare_def.node:
node.device = ""
self.assertProtoEquals(compare_def, graph_def)
# Validate init_op.
init_ops = collection_def[exporter.INIT_OP_KEY].node_list.value
self.assertEquals(len(init_ops), 1)
self.assertEquals(init_ops[0], "init_op")
# Validate signatures.
signatures_any = collection_def[exporter.SIGNATURES_KEY].any_list.value
self.assertEquals(len(signatures_any), 1)
signatures = manifest_pb2.Signatures()
signatures_any[0].Unpack(signatures)
default_signature = signatures.default_signature
self.assertEqual(
default_signature.classification_signature.input.tensor_name, "v0:0")
bindings = signatures.named_signatures["generic"].generic_signature.map
self.assertEquals(bindings["logical_input_A"].tensor_name, "v0:0")
self.assertEquals(bindings["logical_input_B"].tensor_name, "v1:0")
read_foo_signature = (
signatures.named_signatures["foo"].regression_signature)
self.assertEquals(read_foo_signature.input.tensor_name, "v0:0")
self.assertEquals(read_foo_signature.output.tensor_name, "v1:0")
# Validate the assets.
assets_any = collection_def[exporter.ASSETS_KEY].any_list.value
self.assertEquals(len(assets_any), 1)
asset = manifest_pb2.AssetFile()
assets_any[0].Unpack(asset)
assets_path = os.path.join(export_path,
exporter.VERSION_FORMAT_SPECIFIER %
global_step, exporter.ASSETS_DIRECTORY,
"hello42.txt")
asset_contents = gfile.GFile(assets_path).read()
self.assertEqual(asset_contents, "your data here")
self.assertEquals("hello42.txt", asset.filename)
self.assertEquals("filename42:0", asset.tensor_binding.tensor_name)
ignored_asset_path = os.path.join(export_path,
exporter.VERSION_FORMAT_SPECIFIER %
global_step, exporter.ASSETS_DIRECTORY,
"ignored.txt")
self.assertFalse(gfile.Exists(ignored_asset_path))
# Validate graph restoration.
if sharded:
save.restore(sess,
os.path.join(
export_path, exporter.VERSION_FORMAT_SPECIFIER %
global_step, exporter.VARIABLES_FILENAME_PATTERN))
else:
save.restore(sess,
os.path.join(
export_path, exporter.VERSION_FORMAT_SPECIFIER %
global_step, exporter.VARIABLES_FILENAME))
self.assertEqual(10, tf.get_collection("v")[0].eval())
self.assertEqual(20, tf.get_collection("v")[1].eval())
tf.get_collection(exporter.INIT_OP_KEY)[0].run()
self.assertEqual(30, tf.get_collection("v")[2].eval())
def testDuplicateExportRaisesError(self):
export_path = os.path.join(tf.test.get_temp_dir(), "export_duplicates")
self.doBasicsOneExportPath(export_path)
self.assertRaises(RuntimeError, self.doBasicsOneExportPath, export_path)
def testBasics(self):
export_path = os.path.join(tf.test.get_temp_dir(), "export")
self.doBasicsOneExportPath(export_path)
def testBasicsNoShard(self):
export_path = os.path.join(tf.test.get_temp_dir(), "export_no_shard")
self.doBasicsOneExportPath(export_path, sharded=False)
def testClearDevice(self):
export_path = os.path.join(tf.test.get_temp_dir(), "export_clear_device")
self.doBasicsOneExportPath(export_path, clear_devices=True)
def testGC(self):
export_path = os.path.join(tf.test.get_temp_dir(), "gc")
self.doBasicsOneExportPath(export_path, global_step=100)
self.assertEquals(gfile.ListDirectory(export_path), ["00000100"])
self.doBasicsOneExportPath(export_path, global_step=101)
self.assertEquals(
sorted(gfile.ListDirectory(export_path)), ["00000100", "00000101"])
self.doBasicsOneExportPath(export_path, global_step=102)
self.assertEquals(
sorted(gfile.ListDirectory(export_path)), ["00000101", "00000102"])
if __name__ == "__main__":
tf.test.main()
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""System for specifying garbage collection (GC) of path based data.
This framework allows for GC of data specified by path names, for example files
on disk. gc.Path objects each represent a single item stored at a path and may
be a base directory,
/tmp/exports/0/...
/tmp/exports/1/...
...
or a fully qualified file,
/tmp/train-1.ckpt
/tmp/train-2.ckpt
...
A gc filter function takes and returns a list of gc.Path items. Filter
functions are responsible for selecting Path items for preservation or deletion.
Note that functions should always return a sorted list.
For example,
base_dir = "/tmp"
# create the directories
for e in xrange(10):
os.mkdir("%s/%d" % (base_dir, e), 0o755)
# create a simple parser that pulls the export_version from the directory
def parser(path):
match = re.match("^" + base_dir + "/(\\d+)$", path.path)
if not match:
return None
return path._replace(export_version=int(match.group(1)))
path_list = gc.get_paths("/tmp", parser) # contains all ten Paths
every_fifth = gc.mod_export_version(5)
print(every_fifth(path_list))    # shows ["/tmp/0", "/tmp/5"]
largest_three = gc.largest_export_versions(3)
print(largest_three(path_list))  # shows ["/tmp/7", "/tmp/8", "/tmp/9"]
both = gc.union(every_fifth, largest_three)
print(both(path_list))           # shows ["/tmp/0", "/tmp/5",
                                 #        "/tmp/7", "/tmp/8", "/tmp/9"]
# delete everything not in 'both'
to_delete = gc.negation(both)
for p in to_delete(path_list):
  gfile.DeleteRecursively(p.path)  # deletes: "/tmp/1", "/tmp/2",
                                   # "/tmp/3", "/tmp/4", "/tmp/6"
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import heapq
import math
import os
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.platform import gfile
Path = collections.namedtuple('Path', 'path export_version')
def largest_export_versions(n):
"""Creates a filter that keeps the largest n export versions.
Args:
n: number of versions to keep.
Returns:
A filter function that keeps the n largest paths.
"""
def keep(paths):
heap = []
for idx, path in enumerate(paths):
if path.export_version:
heapq.heappush(heap, (path.export_version, idx))
keepers = [paths[i] for _, i in heapq.nlargest(n, heap)]
return sorted(keepers)
return keep
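# Example (values mirror gc_test.py): given paths with export versions
# [8, 9, 10], largest_export_versions(2) keeps the paths for versions 9
# and 10.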
def one_of_every_n_export_versions(n):
"""Creates a filter that keeps one of every n export versions.
Args:
n: interval size.
Returns:
    A filter function that keeps exactly one path from each interval
    [0, n], (n, 2n], (2n, 3n], etc. If more than one path exists in an
    interval, the largest is kept.
"""
def keep(paths):
keeper_map = {} # map from interval to largest path seen in that interval
for p in paths:
if p.export_version is None:
# Skip missing export_versions.
continue
      # Find the interval (with a special case to map export_version = 0 to
      # interval 0).
interval = math.floor(
(p.export_version - 1) / n) if p.export_version else 0
existing = keeper_map.get(interval, None)
if (not existing) or (existing.export_version < p.export_version):
keeper_map[interval] = p
return sorted(keeper_map.values())
return keep
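# Example (values mirror gc_test.py): with n=3 the intervals are [0, 3],
# (3, 6], (6, 9], ..., so from export versions [0, 1, 3, 5, 6, 7, 8, 33] the
# filter keeps [3, 6, 8, 33] -- the largest version in each populated
# interval.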
def mod_export_version(n):
"""Creates a filter that keeps every export that is a multiple of n.
Args:
n: step size.
Returns:
A filter function that keeps paths where export_version % n == 0.
"""
def keep(paths):
keepers = []
for p in paths:
if p.export_version % n == 0:
keepers.append(p)
return sorted(keepers)
return keep
def union(lf, rf):
"""Creates a filter that keeps the union of two filters.
Args:
lf: first filter
rf: second filter
Returns:
    A filter function that keeps the union of the paths kept by lf and rf.
"""
def keep(paths):
l = set(lf(paths))
r = set(rf(paths))
return sorted(list(l|r))
return keep
def negation(f):
"""Negate a filter.
Args:
f: filter function to invert
Returns:
    A filter function that keeps the paths not kept by f.
"""
def keep(paths):
l = set(paths)
r = set(f(paths))
return sorted(list(l-r))
return keep
def get_paths(base_dir, parser):
"""Gets a list of Paths in a given directory.
Args:
base_dir: directory.
parser: a function which gets the raw Path and can augment it with
information such as the export_version, or ignore the path by returning
None. An example parser may extract the export version from a path
such as "/tmp/exports/100" an another may extract from a full file
name such as "/tmp/checkpoint-99.out".
Returns:
A list of Paths contained in the base directory with the parsing function
applied.
By default the following fields are populated,
- Path.path
The parsing function is responsible for populating,
- Path.export_version
"""
raw_paths = gfile.ListDirectory(base_dir)
paths = []
for r in raw_paths:
p = parser(Path(os.path.join(base_dir, r), None))
if p:
paths.append(p)
return sorted(paths)
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for session_bundle.gc."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import re
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.session_bundle import gc
from tensorflow.python.framework import test_util
from tensorflow.python.platform import gfile
def tearDownModule():
gfile.DeleteRecursively(tf.test.get_temp_dir())
class GcTest(test_util.TensorFlowTestCase):
def testLargestExportVersions(self):
paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)]
newest = gc.largest_export_versions(2)
n = newest(paths)
self.assertEquals(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)])
def testModExportVersion(self):
paths = [gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6),
gc.Path("/foo", 9)]
mod = gc.mod_export_version(2)
self.assertEquals(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)])
mod = gc.mod_export_version(3)
self.assertEquals(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)])
def testOneOfEveryNExportVersions(self):
paths = [gc.Path("/foo", 0), gc.Path("/foo", 1), gc.Path("/foo", 3),
gc.Path("/foo", 5), gc.Path("/foo", 6), gc.Path("/foo", 7),
gc.Path("/foo", 8), gc.Path("/foo", 33)]
one_of = gc.one_of_every_n_export_versions(3)
self.assertEquals(one_of(paths),
[gc.Path("/foo", 3), gc.Path("/foo", 6),
gc.Path("/foo", 8), gc.Path("/foo", 33)])
def testOneOfEveryNExportVersionsZero(self):
# Zero is a special case since it gets rolled into the first interval.
# Test that here.
paths = [gc.Path("/foo", 0), gc.Path("/foo", 4), gc.Path("/foo", 5)]
one_of = gc.one_of_every_n_export_versions(3)
self.assertEquals(one_of(paths),
[gc.Path("/foo", 0), gc.Path("/foo", 5)])
def testUnion(self):
paths = []
for i in xrange(10):
paths.append(gc.Path("/foo", i))
f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3))
self.assertEquals(
f(paths), [gc.Path("/foo", 0), gc.Path("/foo", 3),
gc.Path("/foo", 6), gc.Path("/foo", 7),
gc.Path("/foo", 8), gc.Path("/foo", 9)])
def testNegation(self):
paths = [gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6),
gc.Path("/foo", 9)]
mod = gc.negation(gc.mod_export_version(2))
self.assertEquals(
mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)])
mod = gc.negation(gc.mod_export_version(3))
self.assertEquals(
mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)])
def testPathsWithParse(self):
base_dir = os.path.join(tf.test.get_temp_dir(), "paths_parse")
self.assertFalse(gfile.Exists(base_dir))
for p in xrange(3):
gfile.MakeDirs(os.path.join(base_dir, "%d" % p))
# add a base_directory to ignore
gfile.MakeDirs(os.path.join(base_dir, "ignore"))
# create a simple parser that pulls the export_version from the directory.
def parser(path):
match = re.match("^" + base_dir + "/(\\d+)$", path.path)
if not match:
return None
return path._replace(export_version=int(match.group(1)))
self.assertEquals(
gc.get_paths(base_dir, parser=parser),
[gc.Path(os.path.join(base_dir, "0"), 0),
gc.Path(os.path.join(base_dir, "1"), 1),
gc.Path(os.path.join(base_dir, "2"), 2)])
if __name__ == "__main__":
tf.test.main()
syntax = "proto3";
package tensorflow.contrib;
// Signatures of model export.
message Signatures {
// Default signature of the graph.
// WARNING(break-tutorial-inline-code): The following code snippet is
// in-lined in tutorials, please update tutorial documents accordingly
// whenever code changes.
Signature default_signature = 1;
// Named signatures of the graph.
map<string, Signature> named_signatures = 2;
};
// A binding to a tensor including the name and, possibly in the future, type
// or other metadata. For example, this may specify whether a tensor supports
// batch vs single inference.
message TensorBinding {
// The name of the tensor to bind to.
string tensor_name = 1;
};
// An asset file or set of sharded files with the same name that will be bound
// to a tensor at init / session_bundle load time.
message AssetFile {
// The tensor to bind the asset filename to.
TensorBinding tensor_binding = 1;
// The filename within the assets directory. Note: does not include the base
// path or asset directory prefix. Base paths can and will change when models
// are deployed for serving.
string filename = 2;
}
// A Signature specifies the inputs and outputs of commonly used graphs.
message Signature {
oneof type {
RegressionSignature regression_signature = 1;
ClassificationSignature classification_signature = 2;
GenericSignature generic_signature = 3;
}
};
// RegressionSignature specifies a graph that takes an input and returns an
// output.
message RegressionSignature {
TensorBinding input = 1;
TensorBinding output = 2;
};
// ClassificationSignature specifies a graph that takes an input and returns
// classes and their scores.
// WARNING(break-tutorial-inline-code): The following code snippet is
// in-lined in tutorials, please update tutorial documents accordingly
// whenever code changes.
message ClassificationSignature {
TensorBinding input = 1;
TensorBinding classes = 2;
TensorBinding scores = 3;
};
// GenericSignature specifies a map from logical name to Tensor name.
// A typical application of GenericSignature is to use a single
// GenericSignature that includes all of the Tensor nodes and target names
// that may be useful at serving, analysis or debugging time. The
// recommended name for this signature
// in the ModelManifest is "generic_bindings".
message GenericSignature {
map<string, TensorBinding> map = 1;
};
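// Illustrative Signatures message in text format. The tensor names below are
// examples borrowed from exporter_test.py, not requirements of the format:
//   default_signature {
//     classification_signature {
//       input { tensor_name: "v0:0" }
//     }
//   }
//   named_signatures {
//     key: "generic_bindings"
//     value {
//       generic_signature {
//         map { key: "logical_input_A" value { tensor_name: "v0:0" } }
//       }
//     }
//   }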
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/session_bundle/session_bundle.h"
#include <string>
#include <utility>
#include <vector>
#include "google/protobuf/any.pb.h"
#include "tensorflow/contrib/session_bundle/manifest.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace contrib {
namespace {
// Create a session using the given options and load the graph.
Status CreateSessionFromGraphDef(
const tensorflow::SessionOptions& options, const GraphDef& graph,
std::unique_ptr<tensorflow::Session>* session) {
session->reset(NewSession(options));
return (*session)->Create(graph);
}
Status GetMetaGraphDefFromExport(const StringPiece export_dir,
tensorflow::MetaGraphDef* meta_graph_def) {
const string meta_graph_def_path =
tensorflow::io::JoinPath(export_dir, kMetaGraphDefFilename);
return ReadBinaryProto(Env::Default(), meta_graph_def_path, meta_graph_def);
}
// Creates a string tensor.
Tensor CreateStringTensor(const string& value) {
Tensor tensor(DT_STRING, TensorShape({}));
tensor.scalar<string>()() = value;
return tensor;
}
// Adds asset-related tensors (assets_dir and asset files) to the inputs.
void AddAssetsTensorsToInputs(const StringPiece export_dir,
const std::vector<AssetFile>& asset_files,
std::vector<std::pair<string, Tensor>>* inputs) {
if (!asset_files.empty()) {
for (auto& asset : asset_files) {
Tensor assets_file_tensor = CreateStringTensor(tensorflow::io::JoinPath(
tensorflow::io::JoinPath(export_dir, kAssetsDirectory),
asset.filename()));
inputs->push_back(
{asset.tensor_binding().tensor_name(), assets_file_tensor});
}
}
}
// Historically, the model exporter (exporter.py) accepted only savers with
// sharded=True, and therefore always exported checkpoints with pattern-based
// file names. In practice, instead of training from scratch and exporting
// directly, we usually want to restore from an existing checkpoint and then
// export. To support that case, the model exporter now supports reusing a
// saver object restored from an existing checkpoint, which may have
// sharded=False; it then exports the checkpoint under a plain file name.
// This method supports models exported by both types of saver object.
// The change is backward compatible, so no changes are needed for existing
// model exports.
string GetVariablesFilename(const StringPiece export_dir) {
const char kVariablesFilename[] = "export";
const char kVariablesFilenamePattern[] = "export-\?\?\?\?\?-of-\?\?\?\?\?";
if (Env::Default()->FileExists(
tensorflow::io::JoinPath(export_dir, kVariablesFilename))) {
return tensorflow::io::JoinPath(export_dir, kVariablesFilename);
} else {
return tensorflow::io::JoinPath(export_dir, kVariablesFilenamePattern);
}
}
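// For example, with an illustrative <export_dir> the two layouts are:
//   sharded=False: <export_dir>/export
//   sharded=True:  <export_dir>/export-00000-of-00002, ... which is matched
//                  by the "export-?????-of-?????" pattern above.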
Status RunRestoreOp(const StringPiece export_dir,
const std::vector<AssetFile>& asset_files,
const StringPiece restore_op_name,
const StringPiece variables_filename_const_op_name,
tensorflow::Session* session) {
LOG(INFO) << "Running restore op for SessionBundle";
Tensor variables_tensor = CreateStringTensor(
GetVariablesFilename(export_dir));
std::vector<std::pair<string, Tensor>> inputs = {
{variables_filename_const_op_name.ToString(), variables_tensor}};
AddAssetsTensorsToInputs(export_dir, asset_files, &inputs);
return session->Run(inputs, {}, {restore_op_name.ToString()}, nullptr);
}
Status RunInitOp(const StringPiece export_dir,
const std::vector<AssetFile>& asset_files,
const StringPiece init_op_name, tensorflow::Session* session) {
LOG(INFO) << "Running init op for SessionBundle";
std::vector<std::pair<string, Tensor>> inputs;
AddAssetsTensorsToInputs(export_dir, asset_files, &inputs);
return session->Run(inputs, {}, {init_op_name.ToString()}, nullptr);
}
} // namespace
tensorflow::Status LoadSessionBundleFromPath(
const tensorflow::SessionOptions& options, const StringPiece export_dir,
SessionBundle* bundle) {
LOG(INFO) << "Attempting to load a SessionBundle from: " << export_dir;
TF_RETURN_IF_ERROR(
GetMetaGraphDefFromExport(export_dir, &(bundle->meta_graph_def)));
auto collection_def = bundle->meta_graph_def.collection_def();
if (collection_def.find(kGraphKey) != collection_def.end()) {
// Use serving graph_def in MetaGraphDef collection_def.
if (collection_def[kGraphKey].any_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one serving GraphDef in : ",
bundle->meta_graph_def.DebugString()));
}
tensorflow::GraphDef graph_def;
collection_def[kGraphKey].any_list().value(0).UnpackTo(&graph_def);
TF_RETURN_IF_ERROR(
CreateSessionFromGraphDef(options, graph_def, &bundle->session));
} else {
// Fallback to use the graph_def in the MetaGraphDef.
const tensorflow::GraphDef& graph_def = bundle->meta_graph_def.graph_def();
TF_RETURN_IF_ERROR(
CreateSessionFromGraphDef(options, graph_def, &bundle->session));
}
std::vector<AssetFile> asset_files;
auto any_assets = collection_def[kAssetsKey].any_list().value();
  for (const auto& any_asset : any_assets) {
AssetFile asset_file;
any_asset.UnpackTo(&asset_file);
asset_files.push_back(asset_file);
}
TF_RETURN_IF_ERROR(
RunRestoreOp(export_dir, asset_files,
bundle->meta_graph_def.saver_def().restore_op_name(),
bundle->meta_graph_def.saver_def().filename_tensor_name(),
bundle->session.get()));
if (collection_def.find(kInitOpKey) != collection_def.end()) {
if (collection_def[kInitOpKey].node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one serving init op in : ",
bundle->meta_graph_def.DebugString()));
}
return RunInitOp(export_dir, asset_files,
collection_def[kInitOpKey].node_list().value(0),
bundle->session.get());
}
LOG(INFO) << "Done loading SessionBundle";
return Status::OK();
}
} // namespace contrib
} // namespace tensorflow
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Low-level functionality for setting up an inference Session.
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_
#include <memory>
#include "tensorflow/contrib/session_bundle/manifest.pb.h"
#include "tensorflow/contrib/session_bundle/signature.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace contrib {
const char kMetaGraphDefFilename[] = "export.meta";
const char kAssetsDirectory[] = "assets";
const char kInitOpKey[] = "serving_init_op";
const char kAssetsKey[] = "serving_assets";
const char kGraphKey[] = "serving_graph";
// Data and objects loaded from a Python Exporter export.
// WARNING(break-tutorial-inline-code): The following code snippet is
// in-lined in tutorials, please update tutorial documents accordingly
// whenever code changes.
struct SessionBundle {
std::unique_ptr<tensorflow::Session> session;
tensorflow::MetaGraphDef meta_graph_def;
};
// Loads a manifest and an initialized session from the output of an Exporter,
// using the format defined at go/tf-exporter.
tensorflow::Status LoadSessionBundleFromPath(
const tensorflow::SessionOptions& options,
const tensorflow::StringPiece export_dir, SessionBundle* bundle);
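// Minimal usage sketch; the export path is illustrative (see
// session_bundle_test.cc for a complete example):
//   tensorflow::SessionOptions options;
//   SessionBundle bundle;
//   TF_CHECK_OK(LoadSessionBundleFromPath(
//       options, "/tmp/exports/00000123", &bundle));
//   // bundle.session can now run the restored graph.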
} // namespace contrib
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/session_bundle/session_bundle.h"
#include <string>
#include <utility>
#include <vector>
#include "google/protobuf/any.pb.h"
#include "tensorflow/contrib/session_bundle/test_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace contrib {
namespace {
TEST(LoadSessionBundleFromPath, Basic) {
const string export_path = test_util::TestSrcDirPath(
"session_bundle/example/half_plus_two/00000123");
tensorflow::SessionOptions options;
SessionBundle bundle;
TF_ASSERT_OK(LoadSessionBundleFromPath(options, export_path, &bundle));
const string asset_path =
tensorflow::io::JoinPath(export_path, kAssetsDirectory);
// Validate the assets behavior.
std::vector<Tensor> path_outputs;
TF_ASSERT_OK(bundle.session->Run({}, {"filename1:0", "filename2:0"}, {},
&path_outputs));
ASSERT_EQ(2, path_outputs.size());
// Validate the two asset file tensors are set by the init_op and include the
// base_path and asset directory.
test::ExpectTensorEqual<string>(
test::AsTensor<string>(
{tensorflow::io::JoinPath(asset_path, "hello1.txt")},
TensorShape({})),
path_outputs[0]);
test::ExpectTensorEqual<string>(
test::AsTensor<string>(
{tensorflow::io::JoinPath(asset_path, "hello2.txt")},
TensorShape({})),
path_outputs[1]);
// Validate the half plus two behavior.
Tensor input = test::AsTensor<float>({0, 1, 2, 3}, TensorShape({4, 1}));
// Recover the Tensor names of our inputs and outputs.
auto collection_def = bundle.meta_graph_def.collection_def();
Signatures signatures;
ASSERT_EQ(1, collection_def[kSignaturesKey].any_list().value_size());
collection_def[kSignaturesKey].any_list().value(0).UnpackTo(&signatures);
ASSERT_TRUE(signatures.default_signature().has_regression_signature());
const tensorflow::contrib::RegressionSignature regression_signature =
signatures.default_signature().regression_signature();
const string input_name = regression_signature.input().tensor_name();
const string output_name = regression_signature.output().tensor_name();
std::vector<Tensor> outputs;
TF_ASSERT_OK(
bundle.session->Run({{input_name, input}}, {output_name}, {}, &outputs));
ASSERT_EQ(outputs.size(), 1);
test::ExpectTensorEqual<float>(
outputs[0], test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1})));
}
TEST(LoadSessionBundleFromPath, BadExportPath) {
const string export_path = test_util::TestSrcDirPath("/tmp/bigfoot");
tensorflow::SessionOptions options;
options.target = "local";
SessionBundle bundle;
const auto status = LoadSessionBundleFromPath(options, export_path, &bundle);
ASSERT_FALSE(status.ok());
const string msg = status.ToString();
EXPECT_TRUE(msg.find("Not found") != std::string::npos) << msg;
}
} // namespace
} // namespace contrib
} // namespace tensorflow
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/session_bundle/signature.h"
#include <string>
#include <utility>
#include <vector>
#include "google/protobuf/any.pb.h"
#include "tensorflow/contrib/session_bundle/manifest.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
namespace contrib {
namespace {
// Returns OK if the input and output batch sizes match.
Status BatchSizesMatch(const Tensor& input, const Tensor& output) {
// Ensure the number of outputs match the number of inputs.
if (input.dim_size(0) != output.dim_size(0)) {
return errors::Internal(
strings::StrCat("Input batch size did not match output batch size: ",
input.dim_size(0), " vs. ", output.dim_size(0)));
}
return Status::OK();
}
} // namespace
Status GetSignatures(const tensorflow::MetaGraphDef& meta_graph_def,
Signatures* signatures) {
auto collection_def = meta_graph_def.collection_def();
auto any_list = collection_def[kSignaturesKey].any_list();
if (any_list.value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one signatures proto in : ",
meta_graph_def.DebugString()));
}
any_list.value(0).UnpackTo(signatures);
return Status::OK();
}
Status SetSignatures(const Signatures& signatures,
tensorflow::MetaGraphDef* meta_graph_def) {
auto& collection_def = *(meta_graph_def->mutable_collection_def());
auto* any_list = collection_def[kSignaturesKey].mutable_any_list();
any_list->mutable_value()->Clear();
any_list->mutable_value()->Add()->PackFrom(signatures);
return Status::OK();
}
Status GetClassificationSignature(
const tensorflow::MetaGraphDef& meta_graph_def,
ClassificationSignature* signature) {
Signatures signatures;
TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
if (!signatures.has_default_signature()) {
return errors::FailedPrecondition(strings::StrCat(
"Expected a default signature in: ", signatures.DebugString()));
}
if (!signatures.default_signature().has_classification_signature()) {
return errors::FailedPrecondition(
strings::StrCat("Expected a classification signature in: ",
signatures.default_signature().DebugString()));
}
*signature = signatures.default_signature().classification_signature();
return Status::OK();
}
Status GetNamedClassificationSignature(
const string& name, const tensorflow::MetaGraphDef& meta_graph_def,
ClassificationSignature* signature) {
Signatures signatures;
TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
const auto& it = signatures.named_signatures().find(name);
if (it == signatures.named_signatures().end()) {
return errors::NotFound(strings::StrCat("Missing signature named \"", name,
"\" in: ",
signatures.DebugString()));
}
if (!it->second.has_classification_signature()) {
return errors::FailedPrecondition(
strings::StrCat("Expected a classification signature for name \"", name,
"\" in: ", it->second.DebugString()));
}
*signature = it->second.classification_signature();
return Status::OK();
}
Status RunClassification(const ClassificationSignature& signature,
const Tensor& input, Session* session, Tensor* classes,
Tensor* scores) {
std::vector<string> output_tensor_names;
if (classes) {
output_tensor_names.push_back(signature.classes().tensor_name());
}
if (scores) {
output_tensor_names.push_back(signature.scores().tensor_name());
}
// Run the graph with our inputs and outputs.
std::vector<Tensor> outputs;
const Status run_status =
session->Run({{signature.input().tensor_name(), input}},
output_tensor_names, {}, &outputs);
if (!run_status.ok()) {
return run_status;
}
// Ensure the output is shaped how we expect.
  // There should be one string Tensor of shape
  // [batch_size, num_recommendations].
if (outputs.size() != output_tensor_names.size()) {
return errors::Internal(
strings::StrCat("Expected ", output_tensor_names.size(),
" output tensor(s). Got: ", outputs.size()));
}
if (classes) {
*classes = outputs[0];
TF_RETURN_IF_ERROR(BatchSizesMatch(input, *classes));
}
if (scores) {
*scores = outputs[classes ? 1 : 0];
TF_RETURN_IF_ERROR(BatchSizesMatch(input, *scores));
}
return Status::OK();
}
Status GetRegressionSignature(const tensorflow::MetaGraphDef& meta_graph_def,
RegressionSignature* signature) {
Signatures signatures;
TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
if (!signatures.has_default_signature()) {
return errors::FailedPrecondition(strings::StrCat(
"Expected a default signature in: ", signatures.DebugString()));
}
if (!signatures.default_signature().has_regression_signature()) {
return errors::FailedPrecondition(
strings::StrCat("Expected a regression signature in: ",
signatures.default_signature().DebugString()));
}
*signature = signatures.default_signature().regression_signature();
return Status::OK();
}
Status RunRegression(const RegressionSignature& signature,
const Tensor& regression_input, Session* session,
Tensor* regression_output) {
std::vector<string> output_tensor_names;
if (regression_output) {
output_tensor_names.push_back(signature.output().tensor_name());
}
// Run the graph with our inputs and outputs.
std::vector<Tensor> outputs;
const Status run_status =
session->Run({{signature.input().tensor_name(), regression_input}},
output_tensor_names, {}, &outputs);
if (!run_status.ok()) {
return run_status;
}
// Ensure the regression score output is shaped how we expect.
  // There should be one float Tensor of shape
  // [batch_size, num_recommendations].
if (outputs.size() != output_tensor_names.size()) {
return errors::Internal(
strings::StrCat("Expected ", output_tensor_names.size(),
" output tensor(s). Got: ", outputs.size()));
}
if (regression_output) {
*regression_output = outputs[0];
TF_RETURN_IF_ERROR(BatchSizesMatch(regression_input, *regression_output));
}
return Status::OK();
}
Status GetGenericSignature(const string& name,
const tensorflow::MetaGraphDef& meta_graph_def,
GenericSignature* signature) {
Signatures signatures;
TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
const auto& it = signatures.named_signatures().find(name);
if (it == signatures.named_signatures().end()) {
return errors::InvalidArgument(
strings::StrCat("Missing generic signature named \"", name, "\" in ",
signatures.DebugString()));
}
if (!it->second.has_generic_signature()) {
return errors::InvalidArgument(strings::StrCat(
"Expected a generic signature: ", it->second.DebugString()));
}
*signature = it->second.generic_signature();
return Status::OK();
}
Status GetDefaultSignature(const tensorflow::MetaGraphDef& meta_graph_def,
Signature* default_signature) {
Signatures signatures;
TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
*default_signature = signatures.default_signature();
return Status::OK();
}
Status GetNamedSignature(const string& name,
const tensorflow::MetaGraphDef& meta_graph_def,
Signature* signature) {
Signatures signatures;
TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
const auto& it = signatures.named_signatures().find(name);
if (it == signatures.named_signatures().end()) {
return errors::NotFound(strings::StrCat("Missing signature named \"", name,
"\" in: ",
signatures.DebugString()));
}
*signature = it->second;
return Status::OK();
}
Status BindGenericInputs(const GenericSignature& signature,
const std::vector<std::pair<string, Tensor>>& inputs,
std::vector<std::pair<string, Tensor>>* bound_inputs) {
const protobuf::Map<string, contrib::TensorBinding>& bindings =
signature.map();
for (const auto& entry : inputs) {
const auto mapped = bindings.find(entry.first);
if (mapped == bindings.end()) {
return errors::NotFound(
strings::StrCat("Could not find generic binding for: ", entry.first));
}
bound_inputs->push_back({mapped->second.tensor_name(), entry.second});
}
return Status::OK();
}
Status BindGenericNames(const GenericSignature& signature,
const std::vector<string>& input_names,
std::vector<string>* bound_names) {
const protobuf::Map<string, contrib::TensorBinding>& bindings =
signature.map();
for (const string& entry : input_names) {
const auto mapped = bindings.find(entry);
if (mapped == bindings.end()) {
return errors::NotFound(
strings::StrCat("Could not find generic binding for: ", entry));
}
bound_names->push_back(mapped->second.tensor_name());
}
return Status::OK();
}
} // namespace contrib
} // namespace tensorflow
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Helpers for working with TensorFlow exports and their signatures.
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_
#include <string>
#include <utility>
#include <vector>
#include "tensorflow/contrib/session_bundle/manifest.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
namespace contrib {
const char kSignaturesKey[] = "serving_signatures";
// Get Signatures from a MetaGraphDef.
Status GetSignatures(const tensorflow::MetaGraphDef& meta_graph_def,
Signatures* signatures);
// (Re)set Signatures in a MetaGraphDef.
Status SetSignatures(const Signatures& signatures,
tensorflow::MetaGraphDef* meta_graph_def);
// Gets a ClassificationSignature from a MetaGraphDef's default signature.
// Returns an error if the default signature is not a ClassificationSignature,
// or does not exist.
Status GetClassificationSignature(
const tensorflow::MetaGraphDef& meta_graph_def,
ClassificationSignature* signature);
// Gets a named ClassificationSignature from a MetaGraphDef.
// Returns an error if a ClassificationSignature with the given name does
// not exist.
Status GetNamedClassificationSignature(
const string& name, const tensorflow::MetaGraphDef& meta_graph_def,
ClassificationSignature* signature);
// Gets a RegressionSignature from a MetaGraphDef's default signature.
// Returns an error if the default signature is not a RegressionSignature,
// or does not exist.
Status GetRegressionSignature(const tensorflow::MetaGraphDef& meta_graph_def,
RegressionSignature* signature);
// Runs a classification using the provided signature and initialized Session.
// input: input batch of items to classify
// classes: output batch of classes; may be null if not needed
// scores: output batch of scores; may be null if not needed
// Validates sizes of the inputs and outputs are consistent (e.g., input
// batch size equals output batch sizes).
// Does not do any type validation.
Status RunClassification(const ClassificationSignature& signature,
const Tensor& input, Session* session, Tensor* classes,
Tensor* scores);
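// Minimal usage sketch (assumes a SessionBundle `bundle` already loaded via
// LoadSessionBundleFromPath and an `input` batch Tensor):
//   ClassificationSignature signature;
//   TF_CHECK_OK(
//       GetClassificationSignature(bundle.meta_graph_def, &signature));
//   Tensor classes, scores;
//   TF_CHECK_OK(RunClassification(signature, input, bundle.session.get(),
//                                 &classes, &scores));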
// Runs regression using the provided signature and initialized Session.
// input: input batch of items to run the regression model against
// output: output targets
// Validates sizes of the inputs and outputs are consistent (e.g., input
// batch size equals output batch sizes).
// Does not do any type validation.
Status RunRegression(const RegressionSignature& signature, const Tensor& input,
Session* session, Tensor* output);
// Gets the named GenericSignature from a MetaGraphDef.
// Returns an error if a GenericSignature with the given name does not exist.
Status GetGenericSignature(const string& name,
const tensorflow::MetaGraphDef& meta_graph_def,
GenericSignature* signature);
// Gets the default signature from a MetaGraphDef.
Status GetDefaultSignature(const tensorflow::MetaGraphDef& meta_graph_def,
Signature* default_signature);
// Gets a named Signature from a MetaGraphDef.
// Returns an error if a Signature with the given name does not exist.
Status GetNamedSignature(const string& name,
const tensorflow::MetaGraphDef& meta_graph_def,
Signature* default_signature);
// Binds TensorFlow inputs specified by the caller using the logical names
// specified at Graph export time, to the actual Graph names.
// Returns an error if any of the inputs do not have a binding in the export's
// MetaGraphDef.
Status BindGenericInputs(const GenericSignature& signature,
const std::vector<std::pair<string, Tensor>>& inputs,
std::vector<std::pair<string, Tensor>>* bound_inputs);
// Binds the input names specified by the caller using the logical names
// specified at Graph export time, to the actual Graph names. This is useful
// for binding names of both the TensorFlow output tensors and target nodes,
// with the latter (target nodes) being optional and rarely used (if ever) at
// serving time.
// Returns an error if any of the input names do not have a binding in the
// export's MetaGraphDef.
Status BindGenericNames(const GenericSignature& signature,
const std::vector<string>& input_names,
std::vector<string>* bound_names);
} // namespace contrib
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/session_bundle/test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace contrib {
namespace test_util {
string TestSrcDirPath(const string& relative_path) {
  const string base_path = tensorflow::testing::TensorFlowSrcRoot();
  const string contrib_path =
      tensorflow::io::JoinPath(base_path, "/contrib");
  return tensorflow::io::JoinPath(contrib_path, relative_path);
}
} // namespace test_util
} // namespace contrib
} // namespace tensorflow
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_
#include <string>
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace contrib {
namespace test_util {
// Creates an absolute test srcdir path to the linked-in runfiles given a path
// relative to third_party/tensorflow/contrib/.
// e.g. relative path = "session_bundle/example".
string TestSrcDirPath(const string& relative_path);
} // namespace test_util
} // namespace contrib
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_
......@@ -43,17 +43,18 @@ py_library(
],
)
py_test(
name = "learning_test",
srcs = ["python/slim/learning_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/slim",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
# TODO(nsilberman): Fix this test and re-enable.
#py_test(
# name = "learning_test",
# srcs = ["python/slim/learning_test.py"],
# srcs_version = "PY2AND3",
# deps = [
# "//tensorflow:tensorflow_py",
# "//tensorflow/contrib/slim",
# "//tensorflow/python:framework_test_lib",
# "//tensorflow/python:platform_test",
# ],
#)
py_library(
name = "queues",
......
......@@ -53,6 +53,7 @@ licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
"if_android",
"tf_copts",
"tf_cc_test",
"tf_cc_tests",
......@@ -119,7 +120,6 @@ tf_proto_library_cc(
srcs = ["protobuf/worker_service.proto"],
has_services = 1,
cc_api_version = 2,
cc_grpc_version = 1,
cc_libs = [":worker_proto_cc"],
cc_stubby_versions = ["2"],
visibility = [
......@@ -142,7 +142,6 @@ tf_proto_library_cc(
srcs = ["protobuf/master_service.proto"],
has_services = 1,
cc_api_version = 2,
cc_grpc_version = 1,
cc_libs = [":master_proto_cc"],
cc_stubby_versions = ["2"],
visibility = [
......@@ -667,13 +666,7 @@ filegroup(
# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
cc_library(
name = "android_tensorflow_lib_lite",
srcs =
select({
"//tensorflow:android": [
"//tensorflow/core:android_srcs",
],
"//conditions:default": [],
}),
srcs = if_android(["//tensorflow/core:android_srcs"]),
copts = tf_copts() + ["-Os"],
tags = [
"manual",
......@@ -691,12 +684,7 @@ cc_library(
# binary size (by packaging a reduced operator set) is a concern.
cc_library(
name = "android_tensorflow_lib",
srcs = select({
"//tensorflow:android": [
":android_op_registrations_and_gradients",
],
"//conditions:default": [],
}),
srcs = if_android([":android_op_registrations_and_gradients"]),
copts = tf_copts(),
tags = [
"manual",
......@@ -718,13 +706,7 @@ cc_library(
# registration of ops to prune code size.
cc_library(
name = "android_tensorflow_lib_selective_registration",
srcs =
select({
"//tensorflow:android": [
"//tensorflow/core:android_srcs",
],
"//conditions:default": [],
}),
srcs = if_android(["//tensorflow/core:android_srcs"]),
copts = tf_copts() + [
"-Os",
"-DSUPPORT_SELECTIVE_REGISTRATION",
......@@ -758,6 +740,41 @@ filegroup(
visibility = ["//visibility:public"],
)
filegroup(
name = "android_test_srcs",
# TODO(andrewharp/nhua):
# make more test-related sources portable e.g. "platform/test.cc",
srcs = [
"//tensorflow/core:framework/fake_input.cc",
"//tensorflow/core:framework/fake_input.h",
"//tensorflow/core:framework/tensor_testutil.cc",
"//tensorflow/core:framework/tensor_testutil.h",
"//tensorflow/core:platform/test.h",
"//tensorflow/core:util/reporter.cc",
"//tensorflow/core:util/reporter.h",
],
visibility = ["//visibility:public"],
)
# Portable library providing testing functionality for TensorFlow.
cc_library(
name = "android_tensorflow_test_lib",
testonly = 1,
srcs = if_android([":android_test_srcs"]),
copts = tf_copts() + ["-Os"],
tags = [
"manual",
"notap",
],
visibility = ["//visibility:public"],
deps = [
":android_tensorflow_lib",
":protos_cc",
"//tensorflow/core/platform/default/build_config:gtest",
"//third_party/eigen3",
],
)
# -----------------------------------------------------------------------------
# Libraries with GPU facilities that are useful for writing kernels.
cc_library(
......
......@@ -258,7 +258,6 @@ tf_cuda_cc_tests(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:master_service_proto_cc",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
......@@ -294,7 +293,6 @@ tf_cuda_cc_tests(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:master_service_proto_cc",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
......
......@@ -74,11 +74,11 @@ cc_library(
"@grpc//:grpc++_unsecure",
":grpc_client_cq_tag",
":grpc_util",
":grpc_worker_service_impl",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:worker_proto_cc",
"//tensorflow/core:worker_service_proto_cc",
"//tensorflow/core/distributed_runtime:worker_cache_logger",
"//tensorflow/core/distributed_runtime:worker_interface",
],
......@@ -137,14 +137,14 @@ cc_library(
":async_service_interface",
":grpc_call",
":grpc_util",
":grpc_worker_service_impl",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_lib",
"//tensorflow/core:gpu_runtime",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:worker_proto_cc",
"//tensorflow/core:worker_service_proto_cc",
"//tensorflow/core/distributed_runtime:call_options",
"//tensorflow/core/distributed_runtime:graph_mgr",
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
"//tensorflow/core/distributed_runtime:worker_cache",
......@@ -153,16 +153,26 @@ cc_library(
],
)
cc_library(
name = "grpc_worker_service_impl",
srcs = ["grpc_worker_service_impl.cc"],
hdrs = ["grpc_worker_service_impl.h"],
deps = [
"@grpc//:grpc++_unsecure",
":grpc_serialization_traits",
"//tensorflow/core:worker_proto_cc",
],
)
cc_library(
name = "grpc_remote_master",
srcs = ["grpc_remote_master.cc"],
hdrs = ["grpc_remote_master.h"],
deps = [
"@grpc//:grpc++_unsecure",
":grpc_master_service_impl",
":grpc_util",
"//tensorflow/core:lib",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:master_service_proto_cc",
"//tensorflow/core/distributed_runtime:call_options",
"//tensorflow/core/distributed_runtime:master_interface",
],
......@@ -177,16 +187,35 @@ cc_library(
"@grpc//:grpc++_unsecure",
":async_service_interface",
":grpc_call",
":grpc_master_service_impl",
":grpc_util",
"//tensorflow/core:lib",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:master_service_proto_cc",
"//tensorflow/core/distributed_runtime:master",
"//tensorflow/core/distributed_runtime:master_interface",
],
alwayslink = 1,
)
cc_library(
name = "grpc_master_service_impl",
srcs = ["grpc_master_service_impl.cc"],
hdrs = ["grpc_master_service_impl.h"],
deps = [
"@grpc//:grpc++_unsecure",
":grpc_serialization_traits",
"//tensorflow/core:master_proto_cc",
],
)
cc_library(
name = "grpc_serialization_traits",
srcs = [],
hdrs = ["grpc_serialization_traits.h"],
deps = [
"@grpc//:grpc++_unsecure",
],
)
cc_library(
name = "rpc_rendezvous_mgr",
srcs = ["rpc_rendezvous_mgr.cc"],
......
......@@ -36,11 +36,11 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/master.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/master_service.grpc.pb.h"
namespace tensorflow {
......
......@@ -17,11 +17,11 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/master_service.grpc.pb.h"
namespace tensorflow {
......
......@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/lib/core/errors.h"
......@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/protobuf/worker_service.grpc.pb.h"
namespace tensorflow {
......
......@@ -317,6 +317,54 @@ TEST(GrpcSessionTest, MultiDevices) {
}
}
TEST(GrpcSessionTest, LargeTensorSend) {
std::unique_ptr<test::TestCluster> cluster;
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
Graph graph(OpRegistry::Global());
  // Define a 1 GB fill result (1 x 256 x 1024 x 1024 floats).
Tensor fill_shape_tensor(DT_INT32, TensorShape({4}));
fill_shape_tensor.vec<int32>()(0) = 1;
fill_shape_tensor.vec<int32>()(1) = 256;
fill_shape_tensor.vec<int32>()(2) = 1024;
fill_shape_tensor.vec<int32>()(3) = 1024;
Node* fill_shape_node = test::graph::Constant(&graph, fill_shape_tensor);
Tensor fill_val_tensor(DT_FLOAT, TensorShape({}));
fill_val_tensor.flat<float>()(0) = 1.0;
Node* fill_val_node = test::graph::Constant(&graph, fill_val_tensor);
Node* fill_node =
test::graph::Binary(&graph, "Fill", fill_shape_node, fill_val_node);
Tensor max_axes_tensor(DT_INT32, TensorShape({4}));
max_axes_tensor.vec<int32>()(0) = 0;
max_axes_tensor.vec<int32>()(1) = 1;
max_axes_tensor.vec<int32>()(2) = 2;
max_axes_tensor.vec<int32>()(3) = 3;
Node* max_axes_node = test::graph::Constant(&graph, max_axes_tensor);
Node* max_node = test::graph::Reduce(&graph, "Max", fill_node, max_axes_node);
GraphDef def;
test::graph::ToGraphDef(&graph, &def);
SetDevice(&def, fill_node->name(), cluster->devices()[0].name());
SetDevice(&def, fill_node->name(), cluster->devices()[1].name());
std::unique_ptr<Session> session(
NewRemote(Options(cluster->targets()[0], 1000)));
ASSERT_TRUE(session != nullptr);
TF_CHECK_OK(session->Create(def));
{
std::vector<Tensor> outputs;
TF_CHECK_OK(session->Run({}, {max_node->name()}, {}, &outputs));
ASSERT_EQ(1, outputs.size());
IsSingleFloatValue(outputs[0], 1.0);
}
TF_CHECK_OK(session->Close());
}
TEST(GrpcSessionTest, MultiDevices_String) {
std::unique_ptr<test::TestCluster> cluster;
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 1), 2, &cluster));
......
......@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/cancellation.h"
......@@ -40,8 +41,6 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/protobuf/worker_service.grpc.pb.h"
#include "tensorflow/core/protobuf/worker_service.pb.h"
namespace tensorflow {
......
......@@ -51,6 +51,18 @@ class ReaderInterface : public ResourceBase {
virtual void Read(QueueInterface* queue, string* key, string* value,
OpKernelContext* context) = 0;
// Read up to num_records records into keys / values. May get more work from
// *queue if the current work is complete. Sets the status on
// *context with an OutOfRange Status if the current work is
// complete and the queue is done (closed and empty).
// This method may block.
  // The keys/values std::vector pointers are assumed to point to empty
  // vectors (which have most likely had reserve(num_records) called on them).
  // Returns how many records were actually read.
  virtual int64 ReadUpTo(const int64 num_records, QueueInterface* queue,
                         std::vector<string>* keys,
                         std::vector<string>* values,
                         OpKernelContext* context) = 0;
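  // Illustrative call pattern (names are hypothetical):
  //   std::vector<string> keys, values;
  //   keys.reserve(num_records);
  //   values.reserve(num_records);
  //   int64 read = reader->ReadUpTo(num_records, queue, &keys, &values, ctx);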
// Restore this reader to its newly-constructed state.
virtual Status Reset() = 0;
......