Commit a4806a3f authored by Vijay Vasudevan

TensorFlow: upstream changes to git.

Change 109321497
	Move all images to images directory to make docs versioning easier
	- adjust all paths in the docs to point to the new locations
	- remove some now redundant section-order tags added for the old website
Change 109317807
	Added a kernel op to compute the eigendecomposition of a self-adjoint matrix.

	Added a new kernel op called self_adjoint_eig (and a batch_self_adjoint_eig) that
	computes the eigendecomposition of a self-adjoint matrix. The return value is
	the concatenation of the eigenvalues as a row vector, and the eigenvectors.
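
A minimal NumPy sketch of the packed output layout described above (the helper name is hypothetical, and `numpy.linalg.eigh` stands in for the Eigen solver the kernel uses):

```python
import numpy as np

def packed_self_adjoint_eig(matrix):
    """Sketch of the op's [M+1, M] packing: row 0 holds the eigenvalues,
    rows 1..M hold the eigenvector matrix."""
    eigenvalues, eigenvectors = np.linalg.eigh(matrix)
    return np.vstack([eigenvalues[np.newaxis, :], eigenvectors])

m = np.array([[2.0, 1.0], [1.0, 2.0]])
out = packed_self_adjoint_eig(m)
print(out.shape)  # (3, 2): one eigenvalue row plus a 2x2 eigenvector block
print(out[0])     # eigenvalues: [1. 3.]
```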
Change 109310773
	Change `_read32()` in the MNIST input example to return an int.

	Currently we return a 1-D numpy array with 1 element. Numpy has
	recently deprecated the ability to treat this as a scalar, and as a
	result this tutorial fails. The fix returns the 0th element of the
	array instead.
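
The fixed helper, as it appears in the diff to the MNIST input code below; the `io.BytesIO` call is just an illustrative stand-in for the gzip stream:

```python
import io
import numpy

def _read32(bytestream):
    # The MNIST format stores big-endian unsigned 32-bit integers.
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    # frombuffer returns a 1-element array; indexing with [0] yields a
    # true scalar, which newer NumPy requires before using it as an int.
    return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]

print(_read32(io.BytesIO(b'\x00\x00\x08\x03')))  # 2051, the MNIST image magic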
Change 109301269
	Re-arrange TensorBoard demo files.
Change 109273589
	add ci_build for ci.tensorflow.org
Change 109260293
	Speed up the NodeDef -> OpKernel process by not spending time generating
	an error message for a missing "_kernel" attr that will be thrown away.
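
The gist of the optimization, sketched in Python (names are hypothetical; the real change is the `starts_with("_")` guard in AttrSlice::Find below):

```python
def find_attr(attrs, attr_name, node_def):
    """Look up attr_name, deferring expensive error formatting."""
    if attr_name in attrs:
        return attrs[attr_name]
    msg = "No attr named '%s' in NodeDef:" % attr_name
    # Attaching the full NodeDef text is expensive, and internal
    # "_"-prefixed attrs are commonly (and legitimately) absent, with
    # the resulting error thrown away by the caller. Skip it for them.
    if not attr_name.startswith('_') and node_def is not None:
        msg += ' ' + str(node_def)
    raise KeyError(msg)
```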
Change 109257179
	TensorFlow: make event_file_loader_test hermetic by using tempfile
	instead of fixed filenames.  Without this change, running
	event_file_loader_test twice in the same client (locally)
	causes it to fail, because it writes into the same file and appends
	another event, instead of starting from scratch.
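
A sketch of the tempfile pattern (the class name is hypothetical; the actual change is in event_file_loader_test):

```python
import os
import tempfile
import unittest

class EventFileLoaderTestSketch(unittest.TestCase):

    def setUp(self):
        # A fresh temp file per run keeps the test hermetic: a second run
        # in the same client can no longer append to a leftover events
        # file from the first run.
        fd, self._events_path = tempfile.mkstemp(suffix='.events')
        os.close(fd)

    def tearDown(self):
        os.remove(self._events_path)
```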
Change 109256464
	Minor cleanup in TensorBoard server code
Change 109255382
	Change to reduce critical section times in gpu_event_mgr.h:
	(1) Call stream->ThenRecordEvent outside the EventMgr critical section
	(2) Do memory deallocation outside the critical section

	Speeds up one configuration of ptb_word_lm from 2924 words per
	second (wps) to 3278 wps on my desktop machine with a Titan X.
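
The shape of change (2), sketched with a plain Python lock (illustrative only; `EventMgrSketch` and its fields are hypothetical stand-ins for the C++ in gpu_event_mgr.h, and `payload` is any object exposing `close()`):

```python
import threading

class EventMgrSketch:

    def __init__(self):
        self._mu = threading.Lock()
        self._used = []  # (threading.Event, payload) records, guarded by _mu

    def poll(self):
        to_free = []
        with self._mu:
            # Under the lock, only move completed records aside.
            pending = []
            for event, payload in self._used:
                if event.is_set():
                    to_free.append(payload)
                else:
                    pending.append((event, payload))
            self._used = pending
        # Expensive deallocation runs after the lock is released,
        # so other threads are not blocked while buffers are freed.
        for payload in to_free:
            payload.close()
```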
Change 109254843
	Fix use of uninitialized memory in test.
Change 109250995
	python_config.sh needs a license header

	Otherwise the license test fails.
Change 109249914
	add ci_build for ci.tensorflow.org
Change 109249397
	Fixes segfaults in reduce_sum (complex) on GPU.

	Fixes #357
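
A sketch of the call that used to crash, per issue #357 (TF 0.x-era API, assuming a GPU is available); the fix below pins `reduction_indices` to host memory in the GPU kernel registration:

```python
import tensorflow as tf

with tf.device('/gpu:0'):
    x = tf.constant([1 + 2j, 3 + 4j], dtype=tf.complex64)
    s = tf.reduce_sum(x, 0)  # previously segfaulted for complex64 on GPU

with tf.Session() as sess:
    print(sess.run(s))  # (4+6j)
```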

Change 109245652
	add ci_build for ci.tensorflow.org

Base CL: 109321563
Parent bb7a7a88
......@@ -2,7 +2,9 @@
# TensorFlow is a computational framework, primarily for use in machine
# learning applications.
package(default_visibility = ["//tensorflow:internal"])
package(
default_visibility = ["//tensorflow:internal"],
)
licenses(["notice"]) # Apache 2.0
......
......@@ -2,7 +2,9 @@
# TensorFlow is a computational framework, primarily for use in machine
# learning applications.
package(default_visibility = ["//tensorflow:internal"])
package(
default_visibility = ["//tensorflow:internal"],
)
package_group(name = "friends")
......
......@@ -40,13 +40,13 @@ EventMgr::~EventMgr() {
delete e;
}
while (!used_events_.empty()) {
InUse* ue = &used_events_[0];
delete ue->event;
delete ue->mem;
if (ue->bufrec.buf) {
ue->bufrec.alloc->DeallocateRaw(ue->bufrec.buf);
delete used_events_[0].event;
delete used_events_[0].mem;
if (used_events_[0].bufrec.buf) {
used_events_[0].bufrec.alloc->DeallocateRaw(used_events_[0].bufrec.buf);
}
if (ue->func != nullptr) threadpool_.Schedule(ue->func);
if (used_events_[0].func != nullptr)
threadpool_.Schedule(used_events_[0].func);
used_events_.pop_front();
}
}
......@@ -60,17 +60,15 @@ EventMgr::~EventMgr() {
void EventMgr::PollLoop() {
while (!stop_polling_.HasBeenNotified()) {
Env::Default()->SleepForMicroseconds(1 * 1000);
ToFreeVector to_free;
{
mutex_lock l(mu_);
PollEvents(true, &to_free);
PollEvents(true);
}
FreeMemory(to_free);
}
polling_stopped_.Notify();
}
void EventMgr::QueueInUse(gpu::Stream* stream, InUse iu, gpu::Event** e) {
void EventMgr::QueueInUse(gpu::Stream* stream, InUse iu) {
VLOG(2) << "QueueInUse free_events_ " << free_events_.size()
<< " used_events_ " << used_events_.size();
// Events are created on demand, and repeatedly reused. There is no
......@@ -79,9 +77,10 @@ void EventMgr::QueueInUse(gpu::Stream* stream, InUse iu, gpu::Event** e) {
free_events_.push_back(new gpu::Event(exec_));
free_events_.back()->Init();
}
*e = free_events_.back();
gpu::Event* e = free_events_.back();
free_events_.pop_back();
iu.event = *e;
stream->ThenRecordEvent(e);
iu.event = e;
used_events_.push_back(iu);
}
......@@ -104,8 +103,7 @@ void EventMgr::QueueInUse(gpu::Stream* stream, InUse iu, gpu::Event** e) {
// GPU memory use to spike needlessly. An alternative strategy would
// be to throttle new Op execution until the pending event queue
// clears.
void EventMgr::PollEvents(bool is_dedicated_poller,
gtl::InlinedVector<InUse, 4>* to_free) {
void EventMgr::PollEvents(bool is_dedicated_poller) {
VLOG(2) << "PollEvents free_events_ " << free_events_.size()
<< " used_events_ " << used_events_.size();
// Sweep the remaining events in order. If this is the dedicated
......@@ -125,9 +123,11 @@ void EventMgr::PollEvents(bool is_dedicated_poller,
if (!is_dedicated_poller) return; // quit processing queue
break;
case gpu::Event::Status::kComplete:
// Make a copy of the InUse record so we can free it after releasing
// the lock
to_free->push_back(iu);
delete iu.mem;
if (iu.bufrec.buf) iu.bufrec.alloc->DeallocateRaw(iu.bufrec.buf);
// The function must be called in another thread, outside of
// the mutex held here.
if (iu.func != nullptr) threadpool_.Schedule(iu.func);
free_events_.push_back(iu.event);
// Mark this InUse record as completed.
iu.event = nullptr;
......
......@@ -18,10 +18,8 @@ limitations under the License.
#include <deque>
#include <vector>
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/tensor.h"
......@@ -49,15 +47,9 @@ class EventMgr {
// currently enqueued on *stream have completed.
inline void ThenDeleteTensors(perftools::gputools::Stream* stream,
std::vector<Tensor>* tensors) {
ToFreeVector to_free;
::perftools::gputools::Event* e;
{
mutex_lock l(mu_);
QueueTensors(stream, tensors, &e);
PollEvents(false, &to_free);
}
stream->ThenRecordEvent(e);
FreeMemory(to_free);
mutex_lock l(mu_);
QueueTensors(stream, tensors);
PollEvents(false);
}
struct BufRec {
......@@ -69,28 +61,16 @@ class EventMgr {
// on it as soon as all events currently enqueued on *stream have completed.
inline void ThenDeleteBuffer(perftools::gputools::Stream* stream,
BufRec bufrec) {
ToFreeVector to_free;
::perftools::gputools::Event* e;
{
mutex_lock l(mu_);
QueueBuffer(stream, bufrec, &e);
PollEvents(false, &to_free);
}
stream->ThenRecordEvent(e);
FreeMemory(to_free);
mutex_lock l(mu_);
QueueBuffer(stream, bufrec);
PollEvents(false);
}
inline void ThenExecute(perftools::gputools::Stream* stream,
std::function<void()> func) {
ToFreeVector to_free;
::perftools::gputools::Event* e;
{
mutex_lock l(mu_);
QueueFunc(stream, func, &e);
PollEvents(false, &to_free);
}
stream->ThenRecordEvent(e);
FreeMemory(to_free);
mutex_lock l(mu_);
QueueFunc(stream, func);
PollEvents(false);
}
private:
......@@ -105,50 +85,32 @@ class EventMgr {
std::function<void()> func;
};
typedef gtl::InlinedVector<InUse, 4> ToFreeVector;
void FreeMemory(const ToFreeVector& to_free) {
for (const auto& iu : to_free) {
delete iu.mem;
if (iu.bufrec.buf) iu.bufrec.alloc->DeallocateRaw(iu.bufrec.buf);
// The function must be called in another thread.
if (iu.func != nullptr) threadpool_.Schedule(iu.func);
}
}
// Stream-enqueue an unused Event and save with it a collection of
// Tensors and/or a BufRec to be deleted only after the Event
// records.
void QueueInUse(perftools::gputools::Stream* stream, InUse in_use,
::perftools::gputools::Event** e)
void QueueInUse(perftools::gputools::Stream* stream, InUse in_use)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
void QueueTensors(perftools::gputools::Stream* stream,
std::vector<Tensor>* tensors,
::perftools::gputools::Event** e)
std::vector<Tensor>* tensors)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
QueueInUse(stream, {nullptr, tensors, BufRec(), nullptr}, e);
QueueInUse(stream, {nullptr, tensors, BufRec(), nullptr});
}
void QueueBuffer(perftools::gputools::Stream* stream, BufRec bufrec,
::perftools::gputools::Event** e)
void QueueBuffer(perftools::gputools::Stream* stream, BufRec bufrec)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
QueueInUse(stream, {nullptr, nullptr, bufrec, nullptr}, e);
QueueInUse(stream, {nullptr, nullptr, bufrec, nullptr});
}
void QueueFunc(perftools::gputools::Stream* stream,
std::function<void()> func, ::perftools::gputools::Event** e)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
QueueInUse(stream, {nullptr, nullptr, BufRec(), func}, e);
std::function<void()> func) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
QueueInUse(stream, {nullptr, nullptr, BufRec(), func});
}
// This function should be called at roughly the same tempo as
// QueueTensors() to check whether pending events have recorded,
// and then retire them. It appends InUse elements that need cleanup
// to "*to_free". The caller should call FreeMemory(to_free)
// when this returns.
void PollEvents(bool is_dedicated_poller, ToFreeVector* to_free)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
// and then retire them.
void PollEvents(bool is_dedicated_poller) EXCLUSIVE_LOCKS_REQUIRED(mu_);
// An internal polling loop that runs at a low frequency to clear
// straggler Events.
......
......@@ -42,21 +42,13 @@ class TEST_EventMgrHelper {
void QueueTensors(perftools::gputools::Stream* stream,
std::vector<Tensor>* tensors) {
::perftools::gputools::Event* e;
{
mutex_lock l(em_->mu_);
em_->QueueTensors(stream, tensors, &e);
}
stream->ThenRecordEvent(e);
mutex_lock l(em_->mu_);
em_->QueueTensors(stream, tensors);
}
void PollEvents(bool is_dedicated_poller) {
EventMgr::ToFreeVector to_free;
{
mutex_lock l(em_->mu_);
em_->PollEvents(is_dedicated_poller, &to_free);
}
em_->FreeMemory(to_free);
mutex_lock l(em_->mu_);
em_->PollEvents(is_dedicated_poller);
}
private:
......
......@@ -79,7 +79,10 @@ Status AttrSlice::Find(const string& attr_name,
return Status::OK();
}
Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:");
if (ndef_) {
// Skip AttachDef for internal attrs since it is a little bit
// expensive and it is common for them to correctly not be included
// in a NodeDef.
if (!StringPiece(attr_name).starts_with("_") && ndef_) {
s = AttachDef(s, *ndef_);
}
return s;
......
......@@ -46,7 +46,7 @@ class CholeskyOp
const int64 rows = input_matrix_shape.dim_size(0);
if (rows > (1LL << 20)) {
// A big number to cap the cost in case of overflow.
return kint32max;
return kint64max;
} else {
return rows * rows * rows;
}
......@@ -69,8 +69,9 @@ class CholeskyOp
// Perform the actual LL^T Cholesky decomposition. This will only use
// the lower triangular part of data_in by default. The upper triangular
// part of the matrix will not be read.
Eigen::LLT<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic,
Eigen::RowMajor>> llt_decomposition(input);
Eigen::LLT<
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
llt_decomposition(input);
// Output the lower triangular in a dense form.
*output = llt_decomposition.matrixL();
......
......@@ -44,7 +44,7 @@ class DeterminantOp
const int64 rows = input_matrix_shape.dim_size(0);
if (rows > (1LL << 20)) {
// A big number to cap the cost in case of overflow.
return kint32max;
return kint64max;
} else {
return rows * rows * rows;
}
......
......@@ -45,7 +45,7 @@ class MatrixInverseOp
const int64 rows = input_matrix_shape.dim_size(0);
if (rows > (1LL << 20)) {
// A big number to cap the cost in case of overflow.
return kint32max;
return kint64max;
} else {
return rows * rows * rows;
}
......
......@@ -44,7 +44,10 @@ REGISTER_GPU_KERNELS(float);
#undef REGISTER_GPU_KERNELS
REGISTER_KERNEL_BUILDER(
Name("Sum").Device(DEVICE_GPU).TypeConstraint<complex64>("T"),
Name("Sum")
.Device(DEVICE_GPU)
.TypeConstraint<complex64>("T")
.HostMemory("reduction_indices"),
ReductionOp<GPUDevice, complex64, Eigen::internal::SumReducer<complex64>>);
#endif
......
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/linalg_ops.cc.
#include <cmath>
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/Eigenvalues"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/public/tensor_shape.h"
namespace tensorflow {
template <class Scalar, bool SupportsBatchOperationT>
class SelfAdjointEigOp
: public UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT> {
public:
explicit SelfAdjointEigOp(OpKernelConstruction* context)
: UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {}
TensorShape GetOutputMatrixShape(
const TensorShape& input_matrix_shape) override {
int64 d = input_matrix_shape.dim_size(0);
return TensorShape({d + 1, d});
}
int64 GetCostPerUnit(const TensorShape& input_matrix_shape) override {
const int64 rows = input_matrix_shape.dim_size(0);
if (rows > (1LL << 20)) {
// A big number to cap the cost in case of overflow.
return kint64max;
} else {
return rows * rows * rows;
}
}
using
typename UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>::MatrixMap;
using typename UnaryLinearAlgebraOp<Scalar,
SupportsBatchOperationT>::ConstMatrixMap;
void ComputeMatrix(OpKernelContext* context, const ConstMatrixMap& input,
MatrixMap* output) override {
OP_REQUIRES(context, input.rows() == input.cols(),
errors::InvalidArgument("Input matrix must be square."));
if (input.rows() == 0) {
// If X is an empty matrix (0 rows, 0 col), X * X' == X.
// Therefore, we return X.
return;
}
Eigen::SelfAdjointEigenSolver<
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
es(input);
output->row(0) = es.eigenvalues().transpose();
output->bottomRows(input.rows()) = es.eigenvectors();
OP_REQUIRES(context, es.info() == Eigen::Success,
errors::InvalidArgument("Self Adjoint Eigen decomposition was"
"not successful. "
"The input might not be valid."));
}
};
REGISTER_LINALG_OP("SelfAdjointEig", (SelfAdjointEigOp<float, false>), float);
REGISTER_LINALG_OP("SelfAdjointEig", (SelfAdjointEigOp<double, false>), double);
REGISTER_LINALG_OP("BatchSelfAdjointEig", (SelfAdjointEigOp<float, true>),
float);
REGISTER_LINALG_OP("BatchSelfAdjointEig", (SelfAdjointEigOp<double, true>),
double);
} // namespace tensorflow
......@@ -326,7 +326,7 @@ If `indices` is a permutation and `len(indices) == params.shape[0]` then
this operation will permute `params` accordingly.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/Gather.png" alt>
<img style="width:100%" src="../../images/Gather.png" alt>
</div>
)doc");
......
......@@ -57,7 +57,7 @@ For example:
outputs[1] = [30, 40]
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/DynamicPartition.png" alt>
<img style="width:100%" src="../../images/DynamicPartition.png" alt>
</div>
partitions: Any shape. Indices in the range `[0, num_partitions)`.
......@@ -108,7 +108,7 @@ For example:
[51, 52], [61, 62]]
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/DynamicStitch.png" alt>
<img style="width:100%" src="../../images/DynamicStitch.png" alt>
</div>
)doc");
......
......@@ -123,4 +123,41 @@ output: Shape is `[..., M, M]`.
T: The type of values in the input and output.
)doc");
REGISTER_OP("SelfAdjointEig")
.Input("input: T")
.Output("output: T")
.Attr("T: {double, float}")
.Doc(R"doc(
Calculates the Eigen Decomposition of a square Self-Adjoint matrix.
Only the lower-triangular part of the input will be used in this case. The
upper-triangular part will not be read.
The result is an M+1 x M matrix whose first row is the eigenvalues, and
subsequent rows are eigenvectors.
input: Shape is `[M, M]`.
output: Shape is `[M+1, M]`.
T: The type of values in the input and output.
)doc");
REGISTER_OP("BatchSelfAdjointEig")
.Input("input: T")
.Output("output: T")
.Attr("T: {double, float}")
.Doc(R"doc(
Calculates the Eigen Decomposition of a batch of square self-adjoint matrices.
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
form square matrices, with the same constraints as the single matrix
SelfAdjointEig.
The result is a `[..., M+1, M]` matrix with `[..., 0, :]` containing the
eigenvalues, and `[..., 1:, :]` containing the eigenvectors.
input: Shape is `[..., M, M]`.
output: Shape is `[..., M+1, M]`.
T: The type of values in the input and output.
)doc");
} // namespace tensorflow
......@@ -649,7 +649,7 @@ Computes a tensor such that
that `segment_ids[j] == i`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/SegmentSum.png" alt>
<img style="width:100%" src="../../images/SegmentSum.png" alt>
</div>
segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
......@@ -678,7 +678,7 @@ over `j` such that `segment_ids[j] == i` and `N` is the total number of
values summed.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/SegmentMean.png" alt>
<img style="width:100%" src="../../images/SegmentMean.png" alt>
</div>
segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
......@@ -706,7 +706,7 @@ Computes a tensor such that
that `segment_ids[j] == i`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/SegmentProd.png" alt>
<img style="width:100%" src="../../images/SegmentProd.png" alt>
</div>
segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
......@@ -734,7 +734,7 @@ Computes a tensor such that
that `segment_ids[j] == i`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/SegmentMin.png" alt>
<img style="width:100%" src="../../images/SegmentMin.png" alt>
</div>
segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
......@@ -761,7 +761,7 @@ Computes a tensor such that
that `segment_ids[j] == i`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/SegmentMax.png" alt>
<img style="width:100%" src="../../images/SegmentMax.png" alt>
</div>
segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
......@@ -796,7 +796,7 @@ If the sum is empty for a given segment ID `i`, `output[i] = 0`.
`num_segments` should equal the number of distinct segment IDs.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/UnsortedSegmentSum.png" alt>
<img style="width:100%" src="../../images/UnsortedSegmentSum.png" alt>
</div>
segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
......
......@@ -1224,6 +1224,32 @@ op {
}
summary: "Gradients for batch normalization."
}
op {
name: "BatchSelfAdjointEig"
input_arg {
name: "input"
description: "Shape is `[..., M, M]`."
type_attr: "T"
}
output_arg {
name: "output"
description: "Shape is `[..., M+1, M]`."
type_attr: "T"
}
attr {
name: "T"
type: "type"
description: "The type of values in the input and output."
allowed_values {
list {
type: DT_DOUBLE
type: DT_FLOAT
}
}
}
summary: "Calculates the Eigen Decomposition of a batch of square self-adjoint matrices."
description: "The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions\nform square matrices, with the same constraints as the single matrix\nSelfAdjointEig.\n\nThe result is a \'[..., M+1, M] matrix with [..., 0,:] containing the\neigenvalues, and subsequent [...,1:, :] containing the eigenvectors."
}
op {
name: "BiasAdd"
input_arg {
......@@ -6519,6 +6545,32 @@ op {
summary: "Selects elements from `t` or `e`, depending on `condition`."
description: "The `condition`, `t`, and `e` tensors must all have the same shape,\nand the output will also have that shape. The `condition` tensor acts\nas an element-wise mask that chooses, based on the value at each\nelement, whether the corresponding element in the output should be\ntaken from `t` (if true) or `e` (if false). For example:\n\nFor example:\n\n```prettyprint\n# \'condition\' tensor is [[True, False]\n# [True, False]]\n# \'t\' is [[1, 1],\n# [1, 1]]\n# \'e\' is [[2, 2],\n# [2, 2]]\nselect(condition, t, e) ==> [[1, 2],\n [1, 2]]\n```"
}
op {
name: "SelfAdjointEig"
input_arg {
name: "input"
description: "Shape is `[M, M]`."
type_attr: "T"
}
output_arg {
name: "output"
description: "Shape is `[M+1, M]`."
type_attr: "T"
}
attr {
name: "T"
type: "type"
description: "The type of values in the input and output."
allowed_values {
list {
type: DT_DOUBLE
type: DT_FLOAT
}
}
}
summary: "Calculates the Eigen Decomposition of a square Self-Adjoint matrix."
description: "Only the lower-triangular part of the input will be used in this case. The\nupper-triangular part will not be read.\n\nThe result is a M+1 x M matrix whose first row is the eigenvalues, and\nsubsequent rows are eigenvectors."
}
op {
name: "Shape"
input_arg {
......
......@@ -188,7 +188,7 @@ override earlier entries.
Requires `updates.shape = indices.shape + ref.shape[1:]`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/ScatterUpdate.png" alt>
<img style="width:100%" src="../../images/ScatterUpdate.png" alt>
</div>
ref: Should be from a `Variable` node.
......@@ -231,7 +231,7 @@ the same location, their contributions add.
Requires `updates.shape = indices.shape + ref.shape[1:]`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/ScatterAdd.png" alt>
<img style="width:100%" src="../../images/ScatterAdd.png" alt>
</div>
ref: Should be from a `Variable` node.
......@@ -272,7 +272,7 @@ the same location, their (negated) contributions add.
Requires `updates.shape = indices.shape + ref.shape[1:]`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/ScatterSub.png" alt>
<img style="width:100%" src="../../images/ScatterSub.png" alt>
</div>
ref: Should be from a `Variable` node.
......
# Description:
# Tensorflow camera demo app for Android.
package(default_visibility = ["//visibility:public"])
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
......
......@@ -904,7 +904,7 @@ If `indices` is a permutation and `len(indices) == params.shape[0]` then
this operation will permute `params` accordingly.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/Gather.png" alt>
<img style="width:100%" src="../../images/Gather.png" alt>
</div>
##### Args:
......@@ -954,7 +954,7 @@ For example:
outputs[1] = [30, 40]
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/DynamicPartition.png" alt>
<img style="width:100%" src="../../images/DynamicPartition.png" alt>
</div>
##### Args:
......@@ -1013,7 +1013,7 @@ For example:
[51, 52], [61, 62]]
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/DynamicStitch.png" alt>
<img style="width:100%" src="../../images/DynamicStitch.png" alt>
</div>
##### Args:
......
......@@ -192,7 +192,7 @@ The convenience function [`resize_images()`](#resize_images) supports both 4-D
and 3-D tensors as input and output. 4-D tensors are for batches of images,
3-D tensors for individual images.
Other resizing Ops only support 3-D individual images as input:
Other resizing Ops only support 4-D batches of images as input:
[`resize_area`](#resize_area), [`resize_bicubic`](#resize_bicubic),
[`resize_bilinear`](#resize_bilinear),
[`resize_nearest_neighbor`](#resize_nearest_neighbor).
......@@ -200,9 +200,9 @@ Other resizing Ops only support 3-D individual images as input:
Example:
```python
# Decode a JPG image and resize it to 299 by 299.
# Decode a JPG image and resize it to 299 by 299 using default method.
image = tf.image.decode_jpeg(...)
resized_image = tf.image.resize_bilinear(image, [299, 299])
resized_image = tf.image.resize_images(image, 299, 299)
```
- - -
......
......@@ -68,6 +68,7 @@
* [`uniform_unit_scaling_initializer`](../../api_docs/python/state_ops.md#uniform_unit_scaling_initializer)
* [`update_checkpoint_state`](../../api_docs/python/state_ops.md#update_checkpoint_state)
* [`Variable`](../../api_docs/python/state_ops.md#Variable)
* [`variable_op_scope`](../../api_docs/python/state_ops.md#variable_op_scope)
* [`variable_scope`](../../api_docs/python/state_ops.md#variable_scope)
* [`zeros_initializer`](../../api_docs/python/state_ops.md#zeros_initializer)
......@@ -356,20 +357,3 @@
* [`write_graph`](../../api_docs/python/train.md#write_graph)
* [`zero_fraction`](../../api_docs/python/train.md#zero_fraction)
<div class="sections-order" style="display: none;">
<!--
<!-- framework.md -->
<!-- constant_op.md -->
<!-- state_ops.md -->
<!-- array_ops.md -->
<!-- math_ops.md -->
<!-- control_flow_ops.md -->
<!-- image.md -->
<!-- sparse_ops.md -->
<!-- io_ops.md -->
<!-- python_io.md -->
<!-- nn.md -->
<!-- client.md -->
<!-- train.md -->
-->
</div>
......@@ -1340,7 +1340,7 @@ Computes a tensor such that
that `segment_ids[j] == i`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/SegmentSum.png" alt>
<img style="width:100%" src="../../images/SegmentSum.png" alt>
</div>
##### Args:
......@@ -1374,7 +1374,7 @@ Computes a tensor such that
that `segment_ids[j] == i`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/SegmentProd.png" alt>
<img style="width:100%" src="../../images/SegmentProd.png" alt>
</div>
##### Args:
......@@ -1408,7 +1408,7 @@ Computes a tensor such that
that `segment_ids[j] == i`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/SegmentMin.png" alt>
<img style="width:100%" src="../../images/SegmentMin.png" alt>
</div>
##### Args:
......@@ -1441,7 +1441,7 @@ Computes a tensor such that
that `segment_ids[j] == i`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/SegmentMax.png" alt>
<img style="width:100%" src="../../images/SegmentMax.png" alt>
</div>
##### Args:
......@@ -1476,7 +1476,7 @@ over `j` such that `segment_ids[j] == i` and `N` is the total number of
values summed.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/SegmentMean.png" alt>
<img style="width:100%" src="../../images/SegmentMean.png" alt>
</div>
##### Args:
......@@ -1517,7 +1517,7 @@ If the sum is empty for a given segment ID `i`, `output[i] = 0`.
`num_segments` should equal the number of distinct segment IDs.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/UnsortedSegmentSum.png" alt>
<img style="width:100%" src="../../images/UnsortedSegmentSum.png" alt>
</div>
##### Args:
......
......@@ -59,7 +59,7 @@ Computes Rectified Linear 6: `min(max(features, 0), 6)`.
Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise.
See [Fast and Aaccurate Deep Network Learning by Exponential Linear Units (ELUs)
See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)
](http://arxiv.org/abs/1511.07289)
##### Args:
......@@ -872,7 +872,7 @@ See [Noise-contrastive estimation: A new estimation principle for
unnormalized statistical models]
(http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
Also see our [Candidate Sampling Algorithms Reference]
(http://www.tensorflow.org/extras/candidate_sampling.pdf)
(../../extras/candidate_sampling.pdf)
Note: In the case where `num_true` > 1, we assign to each target class
the target probability 1 / `num_true` so that the target probabilities
......@@ -906,7 +906,7 @@ with an otherwise unused class.
`True`, this is a "Sampled Logistic" loss instead of NCE, and we are
learning to generate log-odds instead of log probabilities. See
our [Candidate Sampling Algorithms Reference]
(http://www.tensorflow.org/extras/candidate_sampling.pdf).
(../../extras/candidate_sampling.pdf).
Default is False.
* <b>`name`</b>: A name for the operation (optional).
......@@ -931,7 +931,7 @@ At inference time, you can compute full softmax probabilities with the
expression `tf.nn.softmax(tf.matmul(inputs, weights) + biases)`.
See our [Candidate Sampling Algorithms Reference]
(http://www.tensorflow.org/extras/candidate_sampling.pdf)
(../../extras/candidate_sampling.pdf)
Also see Section 3 of http://arxiv.org/abs/1412.2007 for the math.
......
......@@ -913,6 +913,58 @@ the constructor is used. If that one is `None` too, a
Returns the current variable scope.
- - -
### `tf.variable_op_scope(values, name, default_name, initializer=None)` {#variable_op_scope}
Returns a context manager for defining an op that creates variables.
This context manager validates that the given `values` are from the
same graph, ensures that that graph is the default graph, and pushes a
name scope and a variable scope.
If `name` is not None, it is used as is in the variable scope. If `name`
is None, then `default_name` is used. In that case, if the same name has been
previously used in the same scope, it will be made unique by appending `_N` to
it.
This is intended to be used when defining generic ops and so reuse is always
inherited.
For example, to define a new Python op called `my_op_with_vars`:
```python
def my_op_with_vars(a, b, name=None):
with tf.variable_op_scope([a, b], name, "MyOp") as scope:
a = tf.convert_to_tensor(a, name="a")
b = tf.convert_to_tensor(b, name="b")
c = tf.get_variable('c')
# Define some computation that uses `a`, `b`, and `c`.
return foo_op(..., name=scope)
```
##### Args:
* <b>`values`</b>: The list of `Tensor` arguments that are passed to the op function.
* <b>`name`</b>: The name argument that is passed to the op function; this name is not
uniquified in the variable scope.
* <b>`default_name`</b>: The default name to use if the `name` argument is `None`; this
name will be uniquified.
* <b>`initializer`</b>: A default initializer to pass to variable scope.
##### Returns:
A context manager for use in defining a Python op.
##### Raises:
* <b>`ValueError`</b>: when trying to reuse within a create scope, or create within
a reuse scope, or if reuse is not `None` or `True`.
* <b>`TypeError`</b>: when the types of some arguments are not appropriate.
- - -
### `tf.variable_scope(name_or_scope, reuse=None, initializer=None)` {#variable_scope}
......@@ -983,7 +1035,7 @@ then all its sub-scopes become reusing as well.
well as all sub-scopes; if `None`, we just inherit the parent scope reuse.
* <b>`initializer`</b>: default initializer for variables within this scope.
##### Yields:
##### Returns:
A scope that can be captured and reused.
......@@ -1167,7 +1219,7 @@ override earlier entries.
Requires `updates.shape = indices.shape + ref.shape[1:]`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/ScatterUpdate.png" alt>
<img style="width:100%" src="../../images/ScatterUpdate.png" alt>
</div>
##### Args:
......@@ -1215,7 +1267,7 @@ the same location, their contributions add.
Requires `updates.shape = indices.shape + ref.shape[1:]`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/ScatterAdd.png" alt>
<img style="width:100%" src="../../images/ScatterAdd.png" alt>
</div>
##### Args:
......@@ -1262,7 +1314,7 @@ the same location, their (negated) contributions add.
Requires `updates.shape = indices.shape + ref.shape[1:]`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/ScatterSub.png" alt>
<img style="width:100%" src="../../images/ScatterSub.png" alt>
</div>
##### Args:
......
......@@ -60,10 +60,10 @@ suggest skimming blue, then red.
<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px; display: flex; flex-direction: row">
<a href="../tutorials/mnist/beginners/index.md" title="MNIST for ML Beginners tutorial">
<img style="flex-grow:1; flex-shrink:1; border: 1px solid black;" src="blue_pill.png" alt="MNIST for machine learning beginners tutorial" />
<img style="flex-grow:1; flex-shrink:1; border: 1px solid black;" src="../images/blue_pill.png" alt="MNIST for machine learning beginners tutorial" />
</a>
<a href="../tutorials/mnist/pros/index.md" title="Deep MNIST for ML Experts tutorial">
<img style="flex-grow:1; flex-shrink:1; border: 1px solid black;" src="red_pill.png" alt="Deep MNIST for machine learning experts tutorial" />
<img style="flex-grow:1; flex-shrink:1; border: 1px solid black;" src="../images/red_pill.png" alt="Deep MNIST for machine learning experts tutorial" />
</a>
</div>
<p style="font-size:10px;">Images licensed CC BY-SA 4.0; original by W. Carter</p>
......@@ -77,12 +77,3 @@ TensorFlow features.
* [Download and Setup](../get_started/os_setup.md)
* [Basic Usage](../get_started/basic_usage.md)
* [TensorFlow Mechanics 101](../tutorials/mnist/tf/index.md)
<div class='sections-order' style="display: none;">
<!--
<!-- os_setup.md -->
<!-- basic_usage.md -->
-->
</div>
......@@ -2,7 +2,7 @@
TensorFlow computation graphs are powerful but complicated. The graph visualization can help you understand and debug them. Here's an example of the visualization at work.
![Visualization of a TensorFlow graph](./graph_vis_animation.gif "Visualization of a TensorFlow graph")
![Visualization of a TensorFlow graph](../../images/graph_vis_animation.gif "Visualization of a TensorFlow graph")
*Visualization of a TensorFlow graph.*
To see your own graph, run TensorBoard pointing it to the log directory of the job, click on the graph tab on the top pane and select the appropriate run using the menu at the upper left corner. For in depth information on how to run TensorBoard and make sure you are logging all the necessary information, see [TensorBoard: Visualizing Learning](../../how_tos/summaries_and_tensorboard/index.md).
......@@ -43,10 +43,10 @@ expanded states.
<table width="100%;">
<tr>
<td style="width: 50%;">
<img src="./pool1_collapsed.png" alt="Unexpanded name scope" title="Unexpanded name scope" />
<img src="../../images/pool1_collapsed.png" alt="Unexpanded name scope" title="Unexpanded name scope" />
</td>
<td style="width: 50%;">
<img src="./pool1_expanded.png" alt="Expanded name scope" title="Expanded name scope" />
<img src="../../images/pool1_expanded.png" alt="Expanded name scope" title="Expanded name scope" />
</td>
</tr>
<tr>
......@@ -86,10 +86,10 @@ information since these nodes are usually related to bookkeeping functions.
<table width="100%;">
<tr>
<td style="width: 50%;">
<img src="./conv_1.png" alt="conv_1 is part of the main graph" title="conv_1 is part of the main graph" />
<img src="../../images/conv_1.png" alt="conv_1 is part of the main graph" title="conv_1 is part of the main graph" />
</td>
<td style="width: 50%;">
<img src="./save.png" alt="save is extracted as auxiliary node" title="save is extracted as auxiliary node" />
<img src="../../images/save.png" alt="save is extracted as auxiliary node" title="save is extracted as auxiliary node" />
</td>
</tr>
<tr>
......@@ -111,10 +111,10 @@ with hierarchical nodes, double-clicking expands the series.
<table width="100%;">
<tr>
<td style="width: 50%;">
<img src="./series.png" alt="Sequence of nodes" title="Sequence of nodes" />
<img src="../../images/series.png" alt="Sequence of nodes" title="Sequence of nodes" />
</td>
<td style="width: 50%;">
<img src="./series_expanded.png" alt="Expanded sequence of nodes" title="Expanded sequence of nodes" />
<img src="../../images/series_expanded.png" alt="Expanded sequence of nodes" title="Expanded sequence of nodes" />
</td>
</tr>
<tr>
......@@ -132,15 +132,15 @@ for constants and summary nodes. To summarize, here's a table of node symbols:
Symbol | Meaning
--- | ---
![Name scope](./namespace_node.png "Name scope") | *High-level* node representing a name scope. Double-click to expand a high-level node.
![Sequence of unconnected nodes](./horizontal_stack.png "Sequence of unconnected nodes") | Sequence of numbered nodes that are not connected to each other.
![Sequence of connected nodes](./vertical_stack.png "Sequence of connected nodes") | Sequence of numbered nodes that are connected to each other.
![Operation node](./op_node.png "Operation node") | An individual operation node.
![Constant node](./constant.png "Constant node") | A constant.
![Summary node](./summary.png "Summary node") | A summary node.
![Data flow edge](./dataflow_edge.png "Data flow edge") | Edge showing the data flow between operations.
![Control dependency edge](./control_edge.png "Control dependency edge") | Edge showing the control dependency between operations.
![Reference edge](./reference_edge.png "Reference edge") | A reference edge showing that the outgoing operation node can mutate the incoming tensor.
![Name scope](../../images/namespace_node.png "Name scope") | *High-level* node representing a name scope. Double-click to expand a high-level node.
![Sequence of unconnected nodes](../../images/horizontal_stack.png "Sequence of unconnected nodes") | Sequence of numbered nodes that are not connected to each other.
![Sequence of connected nodes](../../images/vertical_stack.png "Sequence of connected nodes") | Sequence of numbered nodes that are connected to each other.
![Operation node](../../images/op_node.png "Operation node") | An individual operation node.
![Constant node](../../images/constant.png "Constant node") | A constant.
![Summary node](../../images/summary.png "Summary node") | A summary node.
![Data flow edge](../../images/dataflow_edge.png "Data flow edge") | Edge showing the data flow between operations.
![Control dependency edge](../../images/control_edge.png "Control dependency edge") | Edge showing the control dependency between operations.
![Reference edge](../../images/reference_edge.png "Reference edge") | A reference edge showing that the outgoing operation node can mutate the incoming tensor.
## Interaction
......@@ -158,10 +158,10 @@ right corner of the visualization.
<table width="100%;">
<tr>
<td style="width: 50%;">
<img src="./infocard.png" alt="Info card of a name scope" title="Info card of a name scope" />
<img src="../../images/infocard.png" alt="Info card of a name scope" title="Info card of a name scope" />
</td>
<td style="width: 50%;">
<img src="./infocard_op.png" alt="Info card of operation node" title="Info card of operation node" />
<img src="../../images/infocard_op.png" alt="Info card of operation node" title="Info card of operation node" />
</td>
</tr>
<tr>
......@@ -194,10 +194,10 @@ The images below give an illustration for a piece of a real-life graph.
<table width="100%;">
<tr>
<td style="width: 50%;">
<img src="./colorby_structure.png" alt="Color by structure" title="Color by structure" />
<img src="../../images/colorby_structure.png" alt="Color by structure" title="Color by structure" />
</td>
<td style="width: 50%;">
<img src="./colorby_device.png" alt="Color by device" title="Color by device" />
<img src="../../images/colorby_device.png" alt="Color by device" title="Color by device" />
</td>
</tr>
<tr>
......
......@@ -84,19 +84,3 @@ different locations in the model construction code.
The "Variable Scope" mechanism is designed to facilitate that.
[View Tutorial](../how_tos/variable_scope/index.md)
<div class='sections-order' style="display: none;">
<!--
<!-- variables/index.md -->
<!-- ../tutorials/mnist/tf/index.md -->
<!-- summaries_and_tensorboard/index.md -->
<!-- graph_viz/index.md -->
<!-- reading_data/index.md -->
<!-- threading_and_queues/index.md -->
<!-- adding_an_op/index.md -->
<!-- new_data_formats/index.md -->
<!-- using_gpu/index.md -->
<!-- variable_scope/index.md -->
-->
</div>
......@@ -311,7 +311,7 @@ operations, so that our training loop can dequeue examples from the example
queue.
<div style="width:70%; margin-left:12%; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="AnimatedFileQueues.gif">
<img style="width:100%" src="../../images/AnimatedFileQueues.gif">
</div>
The helpers in `tf.train` that create these queues and enqueuing operations add
......
......@@ -8,7 +8,7 @@ your TensorFlow graph, plot quantitative metrics about the execution of your
graph, and show additional data like images that pass through it. When
TensorBoard is fully configured, it looks like this:
![MNIST TensorBoard](./mnist_tensorboard.png "MNIST TensorBoard")
![MNIST TensorBoard](../../images/mnist_tensorboard.png "MNIST TensorBoard")
## Serializing the data
......
......@@ -14,7 +14,7 @@ that takes an item off the queue, adds one to that item, and puts it back on the
end of the queue. Slowly, the numbers on the queue increase.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="IncremeterFifoQueue.gif">
<img style="width:100%" src="../../images/IncremeterFifoQueue.gif">
</div>
`Enqueue`, `EnqueueMany`, and `Dequeue` are special nodes. They take a pointer
......
......@@ -55,14 +55,3 @@ https://github.com/tensorflow/tensorflow/issues) on GitHub.
If you need help with using TensorFlow, please do not use the issue
tracker for that. Instead, direct your questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).
<div class='sections-order' style="display: none;">
<!--
<!-- bib.md -->
<!-- uses.md -->
<!-- faq.md -->
<!-- glossary.md -->
<!-- dims_types.md -->
-->
</div>
......@@ -9,7 +9,7 @@ CIFAR-10 classification is a common benchmark problem in machine learning. The
problem is to classify RGB 32x32 pixel images across 10 categories:
```airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck.```
![CIFAR-10 Samples](./cifar_samples.png "CIFAR-10 Samples, from http://www.cs.toronto.edu/~kriz/cifar.html")
![CIFAR-10 Samples](../../images/cifar_samples.png "CIFAR-10 Samples, from http://www.cs.toronto.edu/~kriz/cifar.html")
For more details refer to the [CIFAR-10 page](http://www.cs.toronto.edu/~kriz/cifar.html)
and a [Tech Report](http://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf)
......@@ -135,7 +135,7 @@ so that we may visualize them in TensorBoard. This is a good practice to verify
that inputs are built correctly.
<div style="width:50%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:70%" src="./cifar_image_summary.png">
<img style="width:70%" src="../../images/cifar_image_summary.png">
</div>
Reading images from disk and distorting them can use a non-trivial amount of
......@@ -164,7 +164,7 @@ Layer Name | Description
Here is a graph generated from TensorBoard describing the inference operation:
<div style="width:15%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="./cifar_graph.png">
<img style="width:100%" src="../../images/cifar_graph.png">
</div>
> **EXERCISE**: The output of `inference` are un-normalized logits. Try editing
......@@ -199,7 +199,7 @@ loss and all these weight decay terms, as returned by the `loss()` function.
We visualize it in TensorBoard with a [`scalar_summary`](../../api_docs/python/train.md#scalar_summary):
![CIFAR-10 Loss](./cifar_loss.png "CIFAR-10 Total Loss")
![CIFAR-10 Loss](../../images/cifar_loss.png "CIFAR-10 Total Loss")
We train the model using standard
[gradient descent](https://en.wikipedia.org/wiki/Gradient_descent)
......@@ -208,7 +208,7 @@ with a learning rate that
[exponentially decays](../../api_docs/python/train.md#exponential_decay)
over time.
![CIFAR-10 Learning Rate Decay](./cifar_lr_decay.png "CIFAR-10 Learning Rate Decay")
![CIFAR-10 Learning Rate Decay](../../images/cifar_lr_decay.png "CIFAR-10 Learning Rate Decay")
The `train()` function adds the operations needed to minimize the objective by
calculating the gradient and updating the learned variables (see
......@@ -289,8 +289,8 @@ For instance, we can watch how the distribution of activations and degree of
sparsity in `local3` features evolve during training:
<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px; display: flex; flex-direction: row">
<img style="flex-grow:1; flex-shrink:1;" src="./cifar_sparsity.png">
<img style="flex-grow:1; flex-shrink:1;" src="./cifar_activations.png">
<img style="flex-grow:1; flex-shrink:1;" src="../../images/cifar_sparsity.png">
<img style="flex-grow:1; flex-shrink:1;" src="../../images/cifar_activations.png">
</div>
Individual loss functions, as well as the total loss, are particularly
......@@ -372,7 +372,7 @@ processing a batch of data.
Here is a diagram of this model:
<div style="width:40%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="./Parallelism.png">
<img style="width:100%" src="../../images/Parallelism.png">
</div>
Note that each GPU computes inference as well as the gradients for a unique
......
......@@ -107,19 +107,3 @@ visual hallucination software.
COMING SOON
<div class='sections-order' style="display: none;">
<!--
<!-- mnist/beginners/index.md -->
<!-- mnist/pros/index.md -->
<!-- mnist/tf/index.md -->
<!-- deep_cnn/index.md -->
<!-- word2vec/index.md -->
<!-- recurrent/index.md -->
<!-- seq2seq/index.md -->
<!-- mandelbrot/index.md -->
<!-- pdes/index.md -->
<!-- mnist/download/index.md -->
-->
</div>
......@@ -110,7 +110,7 @@ Let's see what we've got.
DisplayFractal(ns.eval())
```
![jpeg](mandelbrot_output.jpg)
![jpeg](../../images/mandelbrot_output.jpg)
Not bad!
......
......@@ -3,7 +3,7 @@
*This tutorial is intended for readers who are new to both machine learning and
TensorFlow. If you already
know what MNIST is, and what softmax (multinomial logistic) regression is,
you might prefer this [faster paced tutorial](../../../tutorials/mnist/pros/index.md).*
you might prefer this [faster paced tutorial](../pros/index.md).*
When one learns how to program, there's a tradition that the first thing you do
is print "Hello World." Just like programming has Hello World, machine learning
......@@ -13,7 +13,7 @@ MNIST is a simple computer vision dataset. It consists of images of handwritten
digits like these:
<div style="width:40%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="img/MNIST.png">
<img style="width:100%" src="../../../images/MNIST.png">
</div>
It also includes labels for each image, telling us which digit it is. For
......@@ -61,7 +61,7 @@ Each image is 28 pixels by 28 pixels. We can interpret this as a big array of
numbers:
<div style="width:50%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="img/MNIST-Matrix.png">
<img style="width:100%" src="../../../images/MNIST-Matrix.png">
</div>
We can flatten this array into a vector of 28x28 = 784 numbers. It doesn't
......@@ -83,7 +83,7 @@ the pixel intensity between 0 and 1, for a particular pixel in a particular
image.
<div style="width:40%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="img/mnist-train-xs.png">
<img style="width:100%" src="../../../images/mnist-train-xs.png">
</div>
The corresponding labels in MNIST are numbers between 0 and 9, describing
......@@ -97,7 +97,7 @@ Consequently, `mnist.train.labels` is a
`[55000, 10]` array of floats.
<div style="width:40%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="img/mnist-train-ys.png">
<img style="width:100%" src="../../../images/mnist-train-ys.png">
</div>
We're now ready to actually make our model!
......@@ -128,7 +128,7 @@ classes. Red represents negative weights, while blue represents positive
weights.
<div style="width:40%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="img/softmax-weights.png">
<img style="width:100%" src="../../../images/softmax-weights.png">
</div>
We also add some extra evidence called a bias. Basically, we want to be able
......@@ -175,13 +175,13 @@ although with a lot more \\(x\\)s. For each output, we compute a weighted sum of
the \\(x\\)s, add a bias, and then apply softmax.
<div style="width:55%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="img/softmax-regression-scalargraph.png">
<img style="width:100%" src="../../../images/softmax-regression-scalargraph.png">
</div>
If we write that out as equations, we get:
<div style="width:52%; margin-left:25%; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="img/softmax-regression-scalarequation.png">
<img style="width:100%" src="../../../images/softmax-regression-scalarequation.png">
</div>
We can "vectorize" this procedure, turning it into a matrix multiplication
......@@ -189,7 +189,7 @@ and vector addition. This is helpful for computational efficiency. (It's also
a useful way to think.)
<div style="width:50%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="img/softmax-regression-vectorequation.png">
<img style="width:100%" src="../../../images/softmax-regression-vectorequation.png">
</div>
More compactly, we can just write:
......
......@@ -19,7 +19,7 @@ MNIST is a classic problem in machine learning. The problem is to look at
greyscale 28x28 pixel images of handwritten digits and determine which digit
the image represents, for all the digits from zero to nine.
![MNIST Digits](../tf/mnist_digits.png "MNIST Digits")
![MNIST Digits](../../../images/mnist_digits.png "MNIST Digits")
For more information, refer to [Yann LeCun's MNIST page](http://yann.lecun.com/exdb/mnist/)
or [Chris Olah's visualizations of MNIST](http://colah.github.io/posts/2014-10-Visualizing-MNIST/).
......
......@@ -42,7 +42,7 @@ def maybe_download(filename, work_directory):
def _read32(bytestream):
dt = numpy.dtype(numpy.uint32).newbyteorder('>')
return numpy.frombuffer(bytestream.read(4), dtype=dt)
return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]
def extract_images(filename):
......
......@@ -9,7 +9,7 @@ while constructing a deep convolutional MNIST classifier.
*This introduction assumes familiarity with neural networks and the MNIST
dataset. If you don't have
a background with them, check out the
[introduction for beginners](../../../tutorials/mnist/beginners/index.md).*
[introduction for beginners](../beginners/index.md).*
## Setup
......
......@@ -31,7 +31,7 @@ MNIST is a classic problem in machine learning. The problem is to look at
greyscale 28x28 pixel images of handwritten digits and determine which digit
the image represents, for all the digits from zero to nine.
![MNIST Digits](./mnist_digits.png "MNIST Digits")
![MNIST Digits](../../../images/mnist_digits.png "MNIST Digits")
For more information, refer to [Yann LeCun's MNIST page](http://yann.lecun.com/exdb/mnist/)
or [Chris Olah's visualizations of MNIST](http://colah.github.io/posts/2014-10-Visualizing-MNIST/).
......@@ -90,7 +90,7 @@ loss.
and apply gradients.
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="./mnist_subgraph.png">
<img style="width:100%" src="../../../images/mnist_subgraph.png">
</div>
### Inference
......@@ -401,7 +401,7 @@ summary_writer.add_summary(summary_str, step)
When the events files are written, TensorBoard may be run against the training
folder to display the values from the summaries.
![MNIST TensorBoard](./mnist_tensorboard.png "MNIST TensorBoard")
![MNIST TensorBoard](../../../images/mnist_tensorboard.png "MNIST TensorBoard")
**NOTE**: For more info about how to build and run TensorBoard, please see the accompanying tutorial [TensorBoard: Visualizing Your Training](../../../how_tos/summaries_and_tensorboard/index.md).
......
......@@ -92,7 +92,7 @@ for n in range(40):
DisplayArray(u_init, rng=[-0.1, 0.1])
```
![jpeg](pde_output_1.jpg)
![jpeg](../../images/pde_output_1.jpg)
Now let's specify the details of the differential equation.
......@@ -137,7 +137,7 @@ for i in range(1000):
DisplayArray(U.eval(), rng=[-0.1, 0.1])
```
![jpeg](pde_output_2.jpg)
![jpeg](../../images/pde_output_2.jpg)
Look! Ripples!
......@@ -41,7 +41,7 @@ processes the input and a *decoder* that generates the output.
This basic architecture is depicted below.
<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="basic_seq2seq.png" />
<img style="width:100%" src="../../images/basic_seq2seq.png" />
</div>
Each box in the picture above represents a cell of the RNN, most commonly
......@@ -61,7 +61,7 @@ decoding step. A multi-layer sequence-to-sequence network with LSTM cells and
attention mechanism in the decoder looks like this.
<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="attention_seq2seq.png" />
<img style="width:100%" src="../../images/attention_seq2seq.png" />
</div>
## TensorFlow seq2seq Library
......
......@@ -51,7 +51,7 @@ means that we may need more data in order to successfully train statistical
models. Using vector representations can overcome some of these obstacles.
<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="img/audio-image-text.png" alt>
<img style="width:100%" src="../../images/audio-image-text.png" alt>
</div>
[Vector space models](https://en.wikipedia.org/wiki/Vector_space_model) (VSMs)
......@@ -124,7 +124,7 @@ probability using the score for all other \\(V\\) words \\(w'\\) in the current
context \\(h\\), *at every training step*.
<div style="width:60%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="img/softmax-nplm.png" alt>
<img style="width:100%" src="../../images/softmax-nplm.png" alt>
</div>
On the other hand, for feature learning in word2vec we do not need a full
......@@ -135,7 +135,7 @@ same context. We illustrate this below for a CBOW model. For skip-gram the
direction is simply inverted.
<div style="width:60%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="img/nce-nplm.png" alt>
<img style="width:100%" src="../../images/nce-nplm.png" alt>
</div>
Mathematically, the objective (for each example) is to maximize
......@@ -232,7 +232,7 @@ below (see also for example
[Mikolov et al., 2013](http://www.aclweb.org/anthology/N13-1090)).
<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="img/linear-relationships.png" alt>
<img style="width:100%" src="../../images/linear-relationships.png" alt>
</div>
This explains why these vectors are also useful as features for many canonical
......@@ -329,7 +329,7 @@ After training has finished we can visualize the learned embeddings using
t-SNE.
<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="img/tsne.png" alt>
<img style="width:100%" src="../../images/tsne.png" alt>
</div>
Et voila! As expected, words that are similar end up clustering nearby each
......
# Description:
# Python support for TensorFlow.
package(default_visibility = ["//tensorflow:internal"])
package(
default_visibility = ["//tensorflow:internal"],
)
licenses(["notice"]) # Apache 2.0
......@@ -49,13 +51,23 @@ py_library(
)
py_tests(
name = "platform_tests",
name = "default_platform_tests",
srcs = glob(["platform/default/*_test.py"]),
additional_deps = [
":platform",
":platform_test",
],
prefix = "platform",
prefix = "default_platform",
)
py_tests(
name = "google_platform_tests",
srcs = glob(["platform/google/*_test.py"]),
additional_deps = [
":platform",
":platform_test",
],
prefix = "google_platform",
)
cc_library(
......@@ -843,6 +855,7 @@ cpu_only_kernel_test_list = glob([
"kernel_tests/random_shuffle_queue_test.py",
"kernel_tests/save_restore_ops_test.py",
"kernel_tests/segment_reduction_ops_test.py",
"kernel_tests/self_adjoint_eig_op_test.py",
"kernel_tests/sparse_concat_op_test.py",
"kernel_tests/sparse_matmul_op_test.py",
"kernel_tests/sparse_reorder_op_test.py",
......
......@@ -98,12 +98,6 @@ class Index(Document):
print(" * %s" % link, file=f)
print("", file=f)
# actually include the files right here
print('<div class="sections-order" style="display: none;">\n<!--', file=f)
for filename, _ in self._filename_to_library_map:
print("<!-- %s -->" % filename, file=f)
print("-->\n</div>", file=f)
def collect_members(module_to_name):
"""Collect all symbols from a list of modules.
......
......@@ -151,14 +151,14 @@ class BatchMatmulOpTest(tf.test.TestCase):
self._randComplex([10, 30, 75]), True, True)
def testEmpty(self):
self._compare(np.empty([0, 3, 2]).astype(np.float32),
np.empty([0, 2, 4]).astype(np.float32), False, False)
self._compare(np.empty([3, 2, 0]).astype(np.float32),
np.empty([3, 0, 5]).astype(np.float32), False, False)
self._compare(np.empty([3, 0, 2]).astype(np.float32),
np.empty([3, 2, 5]).astype(np.float32), False, False)
self._compare(np.empty([3, 3, 2]).astype(np.float32),
np.empty([3, 2, 0]).astype(np.float32), False, False)
self._compare(np.zeros([0, 3, 2]).astype(np.float32),
np.zeros([0, 2, 4]).astype(np.float32), False, False)
self._compare(np.zeros([3, 2, 0]).astype(np.float32),
np.zeros([3, 0, 5]).astype(np.float32), False, False)
self._compare(np.zeros([3, 0, 2]).astype(np.float32),
np.zeros([3, 2, 5]).astype(np.float32), False, False)
self._compare(np.zeros([3, 3, 2]).astype(np.float32),
np.zeros([3, 2, 0]).astype(np.float32), False, False)
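The swap from `np.empty` to `np.zeros` above is worth a note: `np.empty` allocates without initializing, so the buffer contents are arbitrary, while `np.zeros` guarantees initialized storage, which keeps the test deterministic under memory-checking tools.

```python
import numpy as np

a = np.empty([2, 3], dtype=np.float32)  # uninitialized: contents are arbitrary
b = np.zeros([2, 3], dtype=np.float32)  # initialized: contents are all 0.0
```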
class BatchMatmulGradientTest(tf.test.TestCase):
......
......@@ -115,7 +115,7 @@ class SumReductionTest(tf.test.TestCase):
self._compareAll(np_arr, [1, 2, 3, 4])
self._compareAll(np_arr, [0, 1, 2, 3, 4])
# Simple tests for various tf.
# Simple tests for various types.
def testDoubleReduce1D(self):
np_arr = np.arange(1, 6).reshape([5]).astype(np.float64)
self._compare(np_arr, [], False)
......@@ -126,6 +126,11 @@ class SumReductionTest(tf.test.TestCase):
self._compare(np_arr, [], False)
self._compare(np_arr, [0], False)
def testComplex64Reduce1D(self):
np_arr = np.arange(1, 6).reshape([5]).astype(np.complex64)
self._compare(np_arr, [], False)
self._compare(np_arr, [0], False)
def testInvalidIndex(self):
np_arr = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
input_tensor = tf.convert_to_tensor(np_arr)
......
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.self_adjoint_eig."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=g-bad-import-order,wildcard-import,unused-import
import tensorflow.python.platform
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
class SelfAdjointEigOpTest(tf.test.TestCase):
def _testEigs(self, x, d, tf_ans, use_gpu=False):
np_eig_val, np_eig_vec = np.linalg.eig(x)
# First check the eigenvalues
self.assertAllClose(sorted(np_eig_val), sorted(tf_ans[0, :]))
# Now check the eigenvectors, which we first need to make canonical. This test
# may still fail if there are two equal eigenvalues, since the eigenvectors are
# then not uniquely determined. For now, assume that we only test matrices with
# distinct eigenvalues.
np_arg = np.argsort(np_eig_val)
tf_arg = np.argsort(tf_ans[0, :])
np_eig_vecs_sorted = np.array([np_eig_vec[:, i] for i in np_arg]).T
tf_eig_vecs_sorted = np.array([tf_ans[1:, i] for i in tf_arg]).T
np_eig_vecs_signed_sorted = np.array([np_eig_vecs_sorted[:, i] *
np.sign(np_eig_vecs_sorted[0, i])
for i in xrange(d)]).T
tf_eig_vecs_signed_sorted = np.array([tf_eig_vecs_sorted[:, i] *
np.sign(tf_eig_vecs_sorted[0, i])
for i in xrange(d)]).T
self.assertAllClose(np_eig_vecs_signed_sorted, tf_eig_vecs_signed_sorted)
def _compareSelfAdjointEig(self, x, use_gpu=False):
with self.test_session() as sess:
tf_eig = tf.self_adjoint_eig(tf.constant(x))
tf_eig_out = sess.run([tf_eig])[0]
d, _ = x.shape
self.assertEqual([d+1, d], tf_eig.get_shape().dims)
self._testEigs(x, d, tf_eig_out, use_gpu)
def _compareBatchSelfAdjointEigRank3(self, x, use_gpu=False):
with self.test_session() as sess:
tf_eig = tf.batch_self_adjoint_eig(tf.constant(x))
tf_out = sess.run([tf_eig])[0]
dlist = x.shape
d = dlist[-2]
self.assertEqual([d+1, d], tf_eig.get_shape().dims[-2:])
# not testing the values.
self.assertEqual(dlist[0], tf_eig.get_shape().dims[0])
for i in xrange(dlist[0]):
self._testEigs(x[i], d, tf_out[i])
def testBasic(self):
self._compareSelfAdjointEig(
np.array([[3., 0., 1.], [0., 2., -2.], [1., -2., 3.]]))
def testBatch(self):
simple_array = np.array([[[1., 0.], [0., 5.]]]) # shape (1, 2, 2)
self._compareBatchSelfAdjointEigRank3(simple_array)
self._compareBatchSelfAdjointEigRank3(np.vstack((simple_array, simple_array)))
odd_sized_array = np.array([[[3., 0., 1.], [0., 2., -2.], [1., -2., 3.]]])
self._compareBatchSelfAdjointEigRank3(
np.vstack((odd_sized_array, odd_sized_array)))
# Generate random positive-definite matrices.
matrices = np.random.rand(10, 5, 5)
for i in xrange(10):
matrices[i] = np.dot(matrices[i].T, matrices[i])
self._compareBatchSelfAdjointEigRank3(matrices)
def testNonSquareMatrix(self):
with self.assertRaises(ValueError):
tf.self_adjoint_eig(tf.constant(np.array([[1., 2., 3.], [3., 4., 5.]])))
def testWrongDimensions(self):
tensor3 = tf.constant([1., 2.])
with self.assertRaises(ValueError):
tf.self_adjoint_eig(tensor3)
if __name__ == "__main__":
tf.test.main()
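For readers skimming the test, a minimal sketch of consuming the packed `[d+1, d]` result; the slicing convention follows the test above, and the session setup is illustrative.

```python
import numpy as np
import tensorflow as tf

x = np.array([[3., 0., 1.], [0., 2., -2.], [1., -2., 3.]])
with tf.Session() as sess:
    packed = sess.run(tf.self_adjoint_eig(tf.constant(x)))
eigenvalues = packed[0, :]    # first row holds the d eigenvalues
eigenvectors = packed[1:, :]  # remaining d rows hold the eigenvectors
```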
......@@ -79,3 +79,24 @@ def _BatchMatrixInverseShape(op):
# The matrices in the batch must be square.
input_shape[-1].assert_is_compatible_with(input_shape[-2])
return [input_shape]
@ops.RegisterShape("SelfAdjointEig")
def _SelfAdjointEigShape(op):
input_shape = op.inputs[0].get_shape().with_rank(2)
# The matrix must be square.
input_shape[0].assert_is_compatible_with(input_shape[1])
d = input_shape.dims[0]
out_shape = tensor_shape.TensorShape([d+1, d])
return [out_shape]
@ops.RegisterShape("BatchSelfAdjointEig")
def _BatchSelfAdjointEigShape(op):
input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
# The matrices in the batch must be square.
input_shape[-1].assert_is_compatible_with(input_shape[-2])
dlist = input_shape.dims
dlist[-2] += 1
out_shape = tensor_shape.TensorShape(dlist)
return [out_shape]
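With these registrations, output shapes can be inferred at graph-construction time, before the kernel ever runs; a small sketch under the same conventions:

```python
import tensorflow as tf

m = tf.placeholder(tf.float64, [5, 5])
print(tf.self_adjoint_eig(m).get_shape())        # (6, 5): d+1 rows, d columns
b = tf.placeholder(tf.float64, [10, 5, 5])
print(tf.batch_self_adjoint_eig(b).get_shape())  # (10, 6, 5)
```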
......@@ -70,6 +70,9 @@ mathematical functions for matrices to your graph.
@@cholesky
@@batch_cholesky
@@self_adjoint_eig
@@batch_self_adjoint_eig
## Complex Number Functions
TensorFlow provides several operations that you can use to add complex number
......
......@@ -711,7 +711,7 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
unnormalized statistical models]
(http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
Also see our [Candidate Sampling Algorithms Reference]
(http://www.tensorflow.org/extras/candidate_sampling.pdf)
(../../extras/candidate_sampling.pdf)
Note: In the case where `num_true` > 1, we assign to each target class
the target probability 1 / `num_true` so that the target probabilities
......@@ -743,7 +743,7 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
`True`, this is a "Sampled Logistic" loss instead of NCE, and we are
learning to generate log-odds instead of log probabilities. See
our [Candidate Sampling Algorithms Reference]
(http://www.tensorflow.org/extras/candidate_sampling.pdf).
(../../extras/candidate_sampling.pdf).
Default is False.
name: A name for the operation (optional).
......@@ -782,7 +782,7 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled,
expression `tf.nn.softmax(tf.matmul(inputs, weights) + biases)`.
See our [Candidate Sampling Algorithms Reference]
(http://www.tensorflow.org/extras/candidate_sampling.pdf)
(../../extras/candidate_sampling.pdf)
Also see Section 3 of http://arxiv.org/abs/1412.2007 for the math.
......
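A minimal sketch of calling `nce_loss` with the argument order shown in the docstring above; the shapes, initializer, and variable names here are assumptions for illustration.

```python
import tensorflow as tf

vocab_size, dim, batch_size, num_sampled = 10000, 128, 64, 32
weights = tf.Variable(tf.truncated_normal([vocab_size, dim], stddev=0.05))
biases = tf.Variable(tf.zeros([vocab_size]))
inputs = tf.placeholder(tf.float32, [batch_size, dim])  # forward activations
labels = tf.placeholder(tf.int64, [batch_size, 1])      # one true class each
loss = tf.reduce_mean(
    tf.nn.nce_loss(weights, biases, inputs, labels, num_sampled, vocab_size))
```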
......@@ -36,10 +36,15 @@ def load_resource(path):
Raises:
IOError: If the path is not found, or the resource can't be opened.
"""
path = os.path.join('tensorflow', path)
tensorflow_root = (
os.path.join(
os.path.dirname(__file__), os.pardir, os.pardir,
os.pardir))
path = os.path.join(tensorflow_root, path)
path = os.path.abspath(path)
try:
with open(path, 'rb') as f:
return f.read()
except IOError as e:
logging.warning('IOError %s on path %s', e, path)
raise e
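A small usage sketch matching the docstring above; the path is hypothetical, and a missing file logs a warning with the resolved absolute path before re-raising IOError, as the test below exercises.

```python
from tensorflow.python.platform.default import _resource_loader as resource_loader

try:
    data = resource_loader.load_resource('some/relative/path')  # hypothetical
except IOError:
    data = None  # resource not found; the warning has already been logged
```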
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.platform import googletest
from tensorflow.python.platform.default import _resource_loader as resource_loader
class DefaultResourceLoaderTest(googletest.TestCase):
def test_exception(self):
with self.assertRaises(IOError):
resource_loader.load_resource("/fake/file/path/dne")
if __name__ == "__main__":
googletest.main()
......@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import os
import tempfile
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
......@@ -32,8 +33,7 @@ class EventFileLoaderTest(test_util.TensorFlowTestCase):
b'\xd5A\x1a\rbrain.Event:1\xec\xf32\x8d')
def _WriteToFile(self, filename, data):
path = os.path.join(self.get_temp_dir(), filename)
with open(path, 'ab') as f:
with open(filename, 'ab') as f:
f.write(data)
def _LoaderForTestFile(self, filename):
......@@ -41,36 +41,41 @@ class EventFileLoaderTest(test_util.TensorFlowTestCase):
os.path.join(self.get_temp_dir(), filename))
def testEmptyEventFile(self):
self._WriteToFile('empty_event_file', b'')
loader = self._LoaderForTestFile('empty_event_file')
filename = tempfile.NamedTemporaryFile().name
self._WriteToFile(filename, b'')
loader = self._LoaderForTestFile(filename)
self.assertEqual(len(list(loader.Load())), 0)
def testSingleWrite(self):
self._WriteToFile('single_event_file', EventFileLoaderTest.RECORD)
loader = self._LoaderForTestFile('single_event_file')
filename = tempfile.NamedTemporaryFile().name
self._WriteToFile(filename, EventFileLoaderTest.RECORD)
loader = self._LoaderForTestFile(filename)
events = list(loader.Load())
self.assertEqual(len(events), 1)
self.assertEqual(events[0].wall_time, 1440183447.0)
self.assertEqual(len(list(loader.Load())), 0)
def testMultipleWrites(self):
self._WriteToFile('staggered_event_file', EventFileLoaderTest.RECORD)
loader = self._LoaderForTestFile('staggered_event_file')
filename = tempfile.NamedTemporaryFile().name
self._WriteToFile(filename, EventFileLoaderTest.RECORD)
loader = self._LoaderForTestFile(filename)
self.assertEqual(len(list(loader.Load())), 1)
self._WriteToFile('staggered_event_file', EventFileLoaderTest.RECORD)
self._WriteToFile(filename, EventFileLoaderTest.RECORD)
self.assertEqual(len(list(loader.Load())), 1)
def testMultipleLoads(self):
self._WriteToFile('multiple_loads_event_file', EventFileLoaderTest.RECORD)
loader = self._LoaderForTestFile('multiple_loads_event_file')
filename = tempfile.NamedTemporaryFile().name
self._WriteToFile(filename, EventFileLoaderTest.RECORD)
loader = self._LoaderForTestFile(filename)
loader.Load()
loader.Load()
self.assertEqual(len(list(loader.Load())), 1)
def testMultipleWritesAtOnce(self):
self._WriteToFile('multiple_event_file', EventFileLoaderTest.RECORD)
self._WriteToFile('multiple_event_file', EventFileLoaderTest.RECORD)
loader = self._LoaderForTestFile('staggered_event_file')
filename = tempfile.NamedTemporaryFile().name
self._WriteToFile(filename, EventFileLoaderTest.RECORD)
self._WriteToFile(filename, EventFileLoaderTest.RECORD)
loader = self._LoaderForTestFile(filename)
self.assertEqual(len(list(loader.Load())), 2)
......
......@@ -14,9 +14,9 @@ subject to an additional IP rights grant found at http://polymer.github.io/PATEN
<meta name="viewport" content="width=device-width, minimum-scale=1.0, initial-scale=1.0, user-scalable=yes">
<title>tf-graph Demo</title>
<!-- Libraries that should be imported in TensorBoard when the Graph visualizer ports to TensorBoard -->
<script src="bower_components/webcomponentsjs/webcomponents-lite.min.js"></script>
<script src="bower_components/es6-promise/promise.min.js"></script>
<link rel="import" href="components/tf-graph/demo/tf-graph-demo.html">
<script src="../../webcomponentsjs/webcomponents-lite.min.js"></script>
<script src="../../es6-promise/promise.min.js"></script>
<link rel="import" href="tf-graph-demo.html">
<style>
html {
width: 100%;
......
......@@ -112,11 +112,6 @@ class ThreadedHTTPServer(SocketServer.ThreadingMixIn,
def main(unused_argv=None):
# Change current working directory to tensorflow/'s parent directory.
server_root = os.path.join(os.path.dirname(__file__),
os.pardir, os.pardir)
os.chdir(server_root)
if FLAGS.debug:
logging.set_verbosity(logging.DEBUG)
......
FROM tensorflow:ci_build.cpu
MAINTAINER Jan Prach <jendap@google.com>
# Install Android SDK.
ENV ANDROID_SDK_FILENAME android-sdk_r24.4.1-linux.tgz
ENV ANDROID_SDK_URL http://dl.google.com/android/${ANDROID_SDK_FILENAME}
ENV ANDROID_API_LEVEL 23
ENV ANDROID_BUILD_TOOLS_VERSION 23.0.2
ENV ANDROID_HOME /opt/android-sdk-linux
ENV PATH ${PATH}:${ANDROID_HOME}/tools:${ANDROID_HOME}/platform-tools
RUN cd /opt && \
wget -q ${ANDROID_SDK_URL} && \
tar -xzf ${ANDROID_SDK_FILENAME} && \
rm ${ANDROID_SDK_FILENAME} && \
echo y | android update sdk --no-ui -a --filter tools,platform-tools,android-${ANDROID_API_LEVEL},build-tools-${ANDROID_BUILD_TOOLS_VERSION}
# Install Android NDK.
ENV ANDROID_NDK_FILENAME android-ndk-r10e-linux-x86_64.bin
ENV ANDROID_NDK_URL http://dl.google.com/android/ndk/${ANDROID_NDK_FILENAME}
ENV ANDROID_NDK_HOME /opt/android-ndk
ENV PATH ${PATH}:${ANDROID_NDK_HOME}
RUN cd /opt && \
wget -q ${ANDROID_NDK_URL} && \
chmod +x ${ANDROID_NDK_FILENAME} && \
./${ANDROID_NDK_FILENAME} -o/opt && \
rm ${ANDROID_NDK_FILENAME} && \
bash -c 'ln -s /opt/android-ndk-* /opt/android-ndk'
FROM ubuntu:14.04
MAINTAINER Jan Prach <jendap@google.com>
# Install dependencies for bazel.
RUN apt-get update && apt-get install -y \
g++ \
pkg-config \
python-dev \
python-numpy \
python-pip \
software-properties-common \
swig \
unzip \
wget \
zip \
zlib1g-dev \
&& \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# Install openjdk 8 for bazel from PPA (it is not available in 14.04).
RUN add-apt-repository -y ppa:openjdk-r/ppa && \
apt-get update && \
apt-get install -y openjdk-8-jdk openjdk-8-jre-headless && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# Install the most recent bazel release.
ENV BAZEL_VERSION 0.1.1
RUN mkdir /bazel && \
cd /bazel && \
wget https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
wget -O /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE.txt && \
chmod +x bazel-*.sh && \
./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh
# Enable bazel auto completion.
RUN echo "source /usr/local/lib/bazel/bin/bazel-complete.bash" >> ~/.bashrc
# Running bazel inside a `docker build` command causes trouble, cf:
# https://github.com/bazelbuild/bazel/issues/134
# The easiest solution is to set up a bazelrc file forcing --batch.
RUN echo "startup --batch" >>/root/.bazelrc
# Similarly, we need to work around sandboxing issues:
# https://github.com/bazelbuild/bazel/issues/418
RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
>>/root/.bazelrc
# Force bazel output to use colors (good for jenkins).
RUN echo "common --color=yes" >>/root/.bazelrc
ENV BAZELRC /root/.bazelrc
FROM tensorflow:ci_build.cpu
MAINTAINER Jan Prach <jendap@google.com>
# Install Cuda.
RUN cd /tmp && \
wget http://developer.download.nvidia.com/compute/cuda/7_0/Prod/local_installers/cuda_7.0.28_linux.run && \
chmod +x *.run && ./cuda_*_linux.run -extract=`pwd` && \
./NVIDIA-Linux-x86_64-*.run -s --no-kernel-module && \
./cuda-linux64-rel-*.run -noprompt && \
rm -rf *
# Set up CUDA variables in .bashrc
RUN echo "CUDA_PATH=/usr/local/cuda" >>~/.bash_profile && \
echo "LD_LIBRARY_PATH=/usr/local/cuda/lib64:/tensorflow_extra_deps/cudnn-6.5-linux-x64-v2" >>~/.bash_profile
# Set up cuda variables.
ENV CUDA_PATH /usr/local/cuda
ENV LD_LIBRARY_PATH /usr/local/cuda/lib64:/tensorflow_extra_deps/cudnn-6.5-linux-x64-v2
# Set up variables for tensorflow to use cuda.
ENV CUDA_TOOLKIT_PATH /usr/local/cuda
ENV CUDNN_INSTALL_PATH /tensorflow_extra_deps/cudnn-6.5-linux-x64-v2
# TensorFlow.org Continuous Integration
This directory contains all the files and setup instructions needed to run
the continuous integration system at [ci.tensorflow.org](http://ci.tensorflow.org).
## How it works
We use [jenkins](https://jenkins-ci.org/) as our continuous integration system.
It runs at [ci.tensorflow.org](http://ci.tensorflow.org).
All the jobs run within [docker](http://www.docker.com/) containers.
Builds can be triggered by a push to master, by a pushed change set, or manually.
A build started in jenkins first pulls the git tree. Then jenkins builds
a docker container (using one of the Dockerfile.* files in this directory),
and the build itself runs within that container.
The source tree lives in the jenkins job workspace. Docker containers for
jenkins are transient and deleted after the build. Containers build very
quickly thanks to docker caching, and individual builds are fast thanks to
bazel caching.
## Implementation Details
* The unusual `bazel-user-cache-for-docker` directory is mapped into the docker
container performing the build using docker's --volume parameter.
This way we cache bazel output between builds.
* The `$HOME/.tensorflow_extra_deps` directory contains
[cudnn](https://developer.nvidia.com/cudnn).
Unfortunately, downloading it requires you to agree to a license.
* The builds directory within this folder contains shell scripts to run within
the container. They essentially contain workarounds for current limitations
of bazel.
## Run It Yourself
1. Install [Docker](http://www.docker.com/). Follow instructions
[on the Docker site](https://docs.docker.com/installation/).
2. Clone the tensorflow repository.
```bash
git clone https://github.com/tensorflow/tensorflow.git
```
3. Go to the tensorflow directory
```bash
cd tensorflow
```
4. Build what you want, for example
```bash
tensorflow/tools/ci_build/ci_build.sh CPU bazel test //tensorflow/...
```
**Note**: For GPU builds you have to create `$HOME/.tensorflow_extra_deps` and
manually install the required dependencies (i.e. cudnn) there, for which you
have to agree to the licenses manually.
#### CUDNN
For GPU builds, download [cudnn](https://developer.nvidia.com/cudnn).
You will get `cudnn-6.5-linux-x64-v2.tgz`. Then run
```bash
mkdir -p $HOME/.tensorflow_extra_deps
tar xzf cudnn-6.5-linux-x64-v2.tgz -C $HOME/.tensorflow_extra_deps
```
## Jobs
The jobs run by [ci.tensorflow.org](http://ci.tensorflow.org) include the following:
```bash
# Note: You can run the following one-liners yourself if you have Docker.
# build and run cpu tests
tensorflow/tools/ci_build/ci_build.sh CPU bazel test --test_timeout=1800 //tensorflow/...
# build gpu
tensorflow/tools/ci_build/ci_build.sh GPU tensorflow/tools/ci_build/builds/gpu.sh
# build pip with gpu support
tensorflow/tools/ci_build/ci_build.sh GPU tensorflow/tools/ci_build/builds/gpu_pip.sh
# build android example app
tensorflow/tools/ci_build/ci_build.sh ANDROID tensorflow/tools/ci_build/builds/android.sh
```
**Note**: The set of jobs and how they are triggered is still evolving.
There are builds for the master branch on cpu, gpu, and android, plus a build
for incoming gerrit changes. Gpu tests and benchmarks are coming soon. Check
[ci.tensorflow.org](http://ci.tensorflow.org) for the current jobs.
#!/usr/bin/env bash
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
set -e
# Download model file.
# Note: This is a workaround. This should be done by bazel.
model_file_name="inception5h.zip"
tmp_model_file_name="${HOME}/.cache/tensorflow_models/${model_file_name}"
mkdir -p $(dirname ${tmp_model_file_name})
[ -e "${tmp_model_file_name}" ] || wget -c "https://storage.googleapis.com/download.tensorflow.org/models/${model_file_name}" -O "${tmp_model_file_name}"
unzip -o "${tmp_model_file_name}" -d tensorflow/examples/android/assets/
# Modify the WORKSPACE file.
# Note: This is a workaround. This should be done by bazel.
if grep -q '^android_sdk_repository' WORKSPACE && grep -q '^android_ndk_repository' WORKSPACE; then
echo "You probably have your WORKSPACE file setup for Android."
else
if [ -z "${ANDROID_API_LEVEL}" -o -z "${ANDROID_BUILD_TOOLS_VERSION}" ] || \
[ -z "${ANDROID_HOME}" -o -z "${ANDROID_NDK_HOME}" ]; then
echo "ERROR: Your WORKSPACE file does not seems to have proper android"
echo " configuration and not all the environment variables expected"
echo " inside ci_build android docker container are set."
echo " Please configure it manually. See: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android/README.md"
else
cat << EOF >> WORKSPACE
android_sdk_repository(
name = "androidsdk",
api_level = ${ANDROID_API_LEVEL},
build_tools_version = "${ANDROID_BUILD_TOOLS_VERSION}",
path = "${ANDROID_HOME}",
)
android_ndk_repository(
name="androidndk",
path="${ANDROID_NDK_HOME}",
api_level=21)
EOF
fi
fi
# Build Android demo app.
bazel build -c opt --copt=-mfpu=neon //tensorflow/examples/android:tensorflow_demo
# Cleanup workarounds.
rm -rf tensorflow/examples/android/assets/
#!/usr/bin/env bash
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
set -e
export TF_NEED_CUDA=1
./configure
bazel build -c opt --config=cuda //tensorflow/...
#!/usr/bin/env bash
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
set -e
export TF_NEED_CUDA=1
./configure
bazel build -c opt --config=cuda //tensorflow/tools/pip_package:build_pip_package
rm -rf /root/.cache/tensorflow-pip
bazel-bin/tensorflow/tools/pip_package/build_pip_package /root/.cache/tensorflow-pip
#!/usr/bin/env bash
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Get the command line arguments.
CONTAINER_TYPE=$( echo "$1" | tr '[:upper:]' '[:lower:]' )
shift 1
COMMAND=("$@")
# Validate command line arguments.
if [ "$#" -lt 1 ] || [[ ! "${CONTAINER_TYPE}" =~ ^(cpu|gpu|android)$ ]]; then
>&2 echo "Usage: $(basename $0) CONTAINER_TYPE COMMAND"
>&2 echo " CONTAINER_TYPE can be 'CPU' or 'GPU'"
>&2 echo " COMMAND is a command (with arguments) to run inside"
>&2 echo " the container."
>&2 echo ""
>&2 echo "Example (run all tests on CPU):"
>&2 echo "$0 CPU bazel test //tensorflow/..."
exit 1
fi
# Figure out the directory where this script is.
SCRIPT_DIR=$( cd ${0%/*} && pwd -P )
# Helper function to traverse directories up until given file is found.
function upsearch () {
test / == "$PWD" && return || \
test -e "$1" && echo "$PWD" && return || \
cd .. && upsearch "$1"
}
# Set up WORKSPACE and BUILD_TAG. Jenkins will set them for you or we pick
# reasonable defaults if you run it outside of Jenkins.
WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}"
BUILD_TAG="${BUILD_TAG:-tf_ci}"
# Additional configuration. You can customize it by modifying the
# environment variables below.
EXTRA_DEPS_DIR="${EXTRA_DEPS_DIR:-${HOME}/.tensorflow_extra_deps}"
# Print arguments.
echo "CONTAINER_TYPE: ${CONTAINER_TYPE}"
echo "COMMAND: ${COMMAND[@]}"
echo "WORKSAPCE: ${WORKSPACE}"
echo "BUILD_TAG: ${BUILD_TAG}"
echo " (docker container name will be ${BUILD_TAG}.${CONTAINER_TYPE})"
echo "EXTRA_DEPS_DIR: ${EXTRA_DEPS_DIR}"
echo ""
# Build the docker containers.
echo "Building CPU container (${BUILD_TAG}.cpu)..."
docker build -t ${BUILD_TAG}.cpu -f ${SCRIPT_DIR}/Dockerfile.cpu ${SCRIPT_DIR}
if [ "${CONTAINER_TYPE}" != "cpu" ]; then
echo "Building container ${BUILD_TAG}.${CONTAINER_TYPE}..."
tmp_dockerfile="${SCRIPT_DIR}/Dockerfile.${CONTAINER_TYPE}.${BUILD_TAG}"
# We need to generate a temporary dockerfile with an overwritten FROM directive.
sed "s/^FROM .*/FROM ${BUILD_TAG}.cpu/" \
${SCRIPT_DIR}/Dockerfile.${CONTAINER_TYPE} > ${tmp_dockerfile}
docker build -t ${BUILD_TAG}.${CONTAINER_TYPE} \
-f ${tmp_dockerfile} ${SCRIPT_DIR}
rm ${tmp_dockerfile}
fi
# Run the command inside the container.
echo "Running '${COMMAND[@]}' inside ${BUILD_TAG}.${CONTAINER_TYPE}..."
mkdir -p ${WORKSPACE}/bazel-user-cache-for-docker
docker run \
-v ${WORKSPACE}/bazel-user-cache-for-docker:/root/.cache \
-v ${WORKSPACE}:/tensorflow \
-v ${EXTRA_DEPS_DIR}:/tensorflow_extra_deps \
-w /tensorflow \
${BUILD_TAG}.${CONTAINER_TYPE} \
"${COMMAND[@]}"
#!/bin/bash
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
set -e -o errexit
......