提交 bf6b536b 编写于 作者: V Vijay Vasudevan

TensorFlow: Upstream changes to git.

Change 109240606
	Fix typo
Change 109240358
	Fix bug in Concat's shape inference due to legacy scalar handling.

	The shape function was inadvertently converting outputs of unknown
	shape (rank=None) to vectors of unknown length (rank=1), due to
	inability to distinguish between legacy scalars and vectors, because
	`max(1, None)` is 1.
Change 109237152
	Remove numarray requirement in python_config.
Change 109234003
	Fix typo in elu documentation.
Change 109232946
	Python must now be configured via ./configure script
Change 109232134
	Backported fixes to the tensor comparison operators from the public Eigen repository
Change 109231761
	Test invalid inputs to softmax_cross_entropy_with_logits.
Change 109230218
	Backported fixes to the tensor comparison operators from the public Eigen repository
Change 109229915
	Correct comments in seq2seq to show the right input types for embedding models.
	(Thanks to hugman@github for bringing this up.)
Change 109229118
	Fix resize_images example in documentation and allow resize_images to run on a single image with partially-known shape.
Change 109228940
	Fix demo and node add/remove button spacing
Change 109227909
	Include Elu in the NN docs.
Change 109227059
	Adds variable_op_scope and makes variable_scope always add a name_scope.

	This creates an op scope for variables that makes it easy to create independent
	operations with a default name by making that name unique for the current scope
	and it allows explicit names that are not made unique.

Change 109224492
	Streamline yuv -> rgb conversion to be done in one pass in native code.

	The entire process now takes ~2ms (including the ByteBuffer.get() calls), down from 10+ ms when the arrays were being interleaved in Java prior to conversion.

	Also abstracting common yuv->rgb color conversion into helper method.
Change 109224389
	Add ability to move nodes in and out of auxiliary nodes in graph.
Change 109217177
	Update generated Op docs.
Change 109215030
	Implementation of the ELU activation function: http://arxiv.org/abs/1511.07289
Change 109209848
	When GPUBFCAllocator runs out of memory, also log a summary
	of chunks in use by size.
Change 109206569
	Switched to the public version of the Eigen::sign method since it supports complex numbers.
Change 109199813
	Modify tensorflow.SequenceExample to support multiple-length sequences.

Base CL: 109241553
上级 fa095c5d
#!/bin/bash
## Set up python-related environment settings
while true; do
fromuser=""
if [ -z "$PYTHON_BIN_PATH" ]; then
default_python_bin_path=$(which python)
read -p "Please specify the location of python. [Default is $default_python_bin_path]: " PYTHON_BIN_PATH
fromuser="1"
if [ -z "$PYTHON_BIN_PATH" ]; then
PYTHON_BIN_PATH=$default_python_bin_path
fi
fi
if [ -e "$PYTHON_BIN_PATH" ]; then
break
fi
echo "Invalid python path. ${PYTHON_BIN_PATH} cannot be found" 1>&2
if [ -z "$fromuser" ]; then
exit 1
fi
PYTHON_BIN_PATH=""
# Retry
done
# Invoke python_config and set up symlinks to python includes
(./util/python/python_config.sh --setup "$PYTHON_BIN_PATH";) || exit -1
## Set up Cuda-related environment settings
while [ "$TF_NEED_CUDA" == "" ]; do
read -p "Do you wish to build TensorFlow with GPU support? [y/n] " INPUT
read -p "Do you wish to build TensorFlow with GPU support? [y/N] " INPUT
case $INPUT in
[Yy]* ) echo -e "GPU support will be enabled for TensorFlow\n"; TF_NEED_CUDA=1;;
[Nn]* ) echo -e "No GPU support will be enabled for TensorFlow\n"; TF_NEED_CUDA=0;;
[Yy]* ) echo "GPU support will be enabled for TensorFlow"; TF_NEED_CUDA=1;;
[Nn]* ) echo "No GPU support will be enabled for TensorFlow"; TF_NEED_CUDA=0;;
"" ) echo "No GPU support will be enabled for TensorFlow"; TF_NEED_CUDA=0;;
* ) echo "Invalid selection: " $INPUT;;
esac
done
......@@ -77,7 +103,7 @@ CUDNN_INSTALL_PATH="$CUDNN_INSTALL_PATH"
EOF
function UnofficialSetting() {
echo -e "\nWARNING: You are configuring unofficial settings in TensorFlow. Because some external libraries are not backward compatible, these settings are largely untested and unsupported. \n"
echo -e "\nWARNING: You are configuring unofficial settings in TensorFlow. Because some external libraries are not backward compatible, these settings are largely untested and unsupported. \n" 1>&2
# Configure the compute capabilities that TensorFlow builds for.
# Since Cuda toolkit is not backward-compatible, this is not guaranteed to work.
......
......@@ -342,6 +342,7 @@ size_t GPUBFCAllocator::AllocatedSize(void* ptr) {
void GPUBFCAllocator::DumpMemoryLog(size_t num_bytes) {
// For each bin: tally up the total number of chunks and bytes.
// Note that bins hold only free chunks.
for (auto bit : bins_) {
Bin* b = bit.second;
......@@ -389,6 +390,24 @@ void GPUBFCAllocator::DumpMemoryLog(size_t num_bytes) {
LOG(INFO) << c->DebugString(true);
}
}
}
// Next show the the chunks that are in use, and also summarize their
// number by size.
std::map<size_t, int> in_use_by_size;
for (auto& it : ptr_to_chunk_map_) {
const Chunk& c = *it.second;
in_use_by_size[c.size]++;
LOG(INFO) << "Chunk at " << it.first << " of size " << c.size;
}
LOG(INFO) << " Summary of in-use Chunks by size: ";
size_t total_bytes = 0;
for (auto& it : in_use_by_size) {
LOG(INFO) << it.second << " Chunks of size " << it.first << " totalling "
<< strings::HumanReadableNumBytes(it.first * it.second);
total_bytes += (it.first * it.second);
}
LOG(INFO) << "Sum Total of in-use chunks: "
<< strings::HumanReadableNumBytes(total_bytes);
}
} // namespace tensorflow
......@@ -115,14 +115,14 @@ class GPUBFCAllocator : public VisitableAllocator {
};
Chunk* AllocateNewChunk(size_t num_bytes);
void SplitChunk(Chunk* c, size_t num_bytes);
void Merge(Chunk* c1, Chunk* c2);
void FreeAndMaybeCoalesce(Chunk* c);
void InsertFreeChunkIntoBin(Chunk* c);
void SplitChunk(Chunk* c, size_t num_bytes) EXCLUSIVE_LOCKS_REQUIRED(lock_);
void Merge(Chunk* c1, Chunk* c2) EXCLUSIVE_LOCKS_REQUIRED(lock_);
void FreeAndMaybeCoalesce(Chunk* c) EXCLUSIVE_LOCKS_REQUIRED(lock_);
void InsertFreeChunkIntoBin(Chunk* c) EXCLUSIVE_LOCKS_REQUIRED(lock_);
void RemoveFreeChunkFromBin(Chunk* c);
void DeleteChunk(Chunk* c);
void DeleteChunk(Chunk* c) EXCLUSIVE_LOCKS_REQUIRED(lock_);
void DumpMemoryLog(size_t num_bytes);
void DumpMemoryLog(size_t num_bytes) EXCLUSIVE_LOCKS_REQUIRED(lock_);
// A Bin is a collection of similar-sized free chunks.
struct Bin {
......@@ -163,7 +163,7 @@ class GPUBFCAllocator : public VisitableAllocator {
// Structures mutable after construction
mutable mutex lock_;
// Chunk * owned.
std::unordered_map<void*, Chunk*> ptr_to_chunk_map_;
std::unordered_map<void*, Chunk*> ptr_to_chunk_map_ GUARDED_BY(lock_);
// Called once on each region, ASAP.
std::vector<Visitor> region_visitors_;
......
......@@ -7,7 +7,21 @@ import "tensorflow/core/example/feature.proto";
package tensorflow;
// Example for a movie recommendation application:
// An Example is a mostly-normalized data format for storing data for
// training and inference. It contains a key-value store (features); where
// each key (string) maps to a Feature message (which is oneof packed BytesList,
// FloatList, or Int64List). This flexible and compact format allows the
// storage of large amounts of typed data, but requires that the data shape
// and use be determined by the configuration files and parsers that are used to
// read and write this format. That is, the Example is mostly *not* a
// self-describing format. In TensorFlow, Examples are read in row-major
// format, so any configuration that describes data with rank-2 or above
// should keep this in mind. For example, to store an M x N matrix of Bytes,
// the BytesList must contain M*N bytes, with M rows of N contiguous values
// each. That is, the BytesList value must store the matrix as:
// .... row 0 .... .... row 1 .... // ........... // ... row M-1 ....
//
// An Example for a movie recommendation application:
// features {
// feature {
// key: "age"
......@@ -58,7 +72,7 @@ package tensorflow;
// }
// }
//
// A conformant data set obeys the following conventions:
// A conformant Example data set obeys the following conventions:
// - If a Feature K exists in one example with data type T, it must be of
// type T in all other examples when present. It may be omitted.
// - The number of instances of Feature K list data may vary across examples,
......@@ -72,23 +86,182 @@ message Example {
Features features = 1;
};
// Example representing a ranking instance.
message RankingExample {
Features context = 1;
repeated Features positive = 2;
repeated Features negative = 3;
};
// A SequenceExample is an Example representing one or more sequences, and
// some context. The context contains features which apply to the entire
// example. The feature_lists contain a key, value map where each key is
// associated with a repeated set of Features (a FeatureList).
//
// A SequenceExample for a movie recommendation application:
//
// context: {
// feature: {
// key : "locale"
// value: {
// bytes_list: {
// value: [ "pt_BR" ]
// }
// }
// }
// feature: {
// key : "age"
// value: {
// float_list: {
// value: [ 19.0 ]
// }
// }
// }
// feature: {
// key : "favorites"
// value: {
// bytes_list: {
// value: [ "Majesty Rose", "Savannah Outen", "One Direction" ]
// }
// }
// }
// }
// feature_lists: {
// feature_list: {
// key : "movie_ratings"
// value: {
// feature: {
// float_list: {
// value: [ 4.5 ]
// }
// }
// feature: {
// float_list: {
// value: [ 5.0 ]
// }
// }
// }
// }
// feature_list: {
// key : "movie_names"
// value: {
// feature: {
// bytes_list: {
// value: [ "The Shawshank Redemption" ]
// }
// }
// feature: {
// bytes_list: {
// value: [ "Fight Club" ]
// }
// }
// }
// }
// }
//
// A conformant SequenceExample data set obeys the following conventions:
//
// Context:
// - All conformant context features K must obey the same conventions as
// a conformant Example's features (see above).
// Feature lists:
// - A FeatureList L may be missing in an example; it is up to the
// parser configuration to determine if this is allowed or considered
// an empty list (zero length).
// - If a FeatureList L exists, it may be empty (zero length).
// - If a FeatureList L is non-empty, all features within the FeatureList
// must have data type T, and all features within the FeatureList must
// have the same size.
// - If a FeatureList L exists in one example with data type T,
// it must be of type T in all other examples when present.
// - If a FeatureList L exists in one example having features' sizes all S,
// these sizes must be S in all other examples when present.
//
// Examples of conformant and non-conformant examples' FeatureLists:
//
// Conformant FeatureLists:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } } }
// } }
//
// Non-conformant FeatureLists (mismatched types):
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { int64_list: { value: [ 5 ] } } }
// } }
//
// Non-conformant FeatureLists (mismatched sizes):
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0, 6.0 ] } } }
// } }
//
// Conformant pair of SequenceExample
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } } }
// } }
// and:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } }
// feature: { float_list: { value: [ 2.0 ] } } }
// } }
//
// Conformant pair of SequenceExample
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } } }
// } }
// and:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { }
// } }
//
// Conditionally conformant pair of SequenceExample, the parser configuration
// determines if the second feature_lists is consistent (zero-length) or
// invalid (missing "movie_ratings"):
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } } }
// } }
// and:
// feature_lists: { }
//
// Non-conformant pair of SequenceExample (mismatched types)
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } } }
// } }
// and:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { int64_list: { value: [ 4 ] } }
// feature: { int64_list: { value: [ 5 ] } }
// feature: { int64_list: { value: [ 2 ] } } }
// } }
//
// Non-conformant pair of SequenceExample (mismatched sizes)
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } } }
// } }
// and:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.0, 5.0 ] } }
// feature: { float_list: { value: [ 5.0, 3.0 ] } }
// } }
// Example representing a sequence.
// The context contains features which apply to the entire sequence.
// Each element in example represents an entry in the sequence.
message SequenceExample {
Features context = 1;
repeated Features features = 2;
FeatureLists feature_lists = 2;
};
// Example representing a list of feature maps.
// The context contains features which apply to all feature maps.
message InferenceExample {
Features context = 1;
repeated Features features = 2;
......
......@@ -6,7 +6,8 @@
// - float
// - int64
//
// Base features are contained in Lists which may hold zero or more values.
// A Feature contains Lists which may hold zero or more values. These
// lists are the base values BytesList, FloatList, Int64List.
//
// Features are organized into categories by name. The Features message
// contains the mapping from name to Feature.
......@@ -50,12 +51,25 @@
// value: 9.99
// }}
// }
//
syntax = "proto3";
// option cc_enable_arenas = true;
package tensorflow;
// Containers to hold repeated fundamental values.
message BytesList {
repeated bytes value = 1;
}
message FloatList {
repeated float value = 1 [packed = true];
}
message Int64List {
repeated int64 value = 1 [packed = true];
}
// Containers for non-sequential data.
message Feature {
// Each feature can be exactly one kind.
oneof kind {
......@@ -70,13 +84,19 @@ message Features {
map<string, Feature> feature = 1;
};
// Containers to hold repeated fundamental features.
message BytesList {
repeated bytes value = 1;
}
message FloatList {
repeated float value = 1 [packed = true];
}
message Int64List {
repeated int64 value = 1 [packed = true];
}
// Containers for sequential data.
//
// A FeatureList contains lists of Features. These may hold zero or more
// Feature values.
//
// FeatureLists are organized into categories by name. The FeatureLists message
// contains the mapping from name to FeatureList.
//
message FeatureList {
repeated Feature feature = 1;
};
message FeatureLists {
// Map from feature name to feature list.
map<string, FeatureList> feature_list = 1;
};
......@@ -29,16 +29,6 @@ limitations under the License.
namespace Eigen {
namespace internal {
template <typename T>
struct scalar_sign_op {
// TODO(zhifengc): this only works for real types. In theory,
// sign(x) = x / |x| works for both real and complex values.
EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op);
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const {
return T(x > T(0)) - T(x < T(0));
}
};
// TODO(zhifengc): Eigen::internal::pow_impl does not have proper
// EIGEN host/device decoration. We duplicate code here for now.
template <typename T, bool IsInteger>
......
......@@ -312,7 +312,7 @@ void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
int32 s = queues_[0].size();
if (closed_ && s < attempt->elements_requested) {
attempt->context->SetStatus(errors::OutOfRange(
"RandomSuffleQueue '", name_, "' is closed and has ",
"RandomShuffleQueue '", name_, "' is closed and has ",
"insufficient elements (requested ",
attempt->elements_requested, ", current size ", s, ")"));
return kComplete;
......
......@@ -42,6 +42,27 @@ class ReluOp : public UnaryElementWiseOp<T, ReluOp<Device, T>> {
}
};
template <typename Device, typename T>
class ReluGradOp : public BinaryElementWiseOp<T, ReluGradOp<Device, T>> {
public:
using BinaryElementWiseOp<T, ReluGradOp<Device, T>>::BinaryElementWiseOp;
// INPUTS:
// g (gradients): backpropagated gradients
// a (inputs): inputs that were passed to ReluOp()
// OUTPUT:
// gradients to backprop
template <int NDIMS>
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
Tensor* output) {
OP_REQUIRES(context, a.IsSameSize(g),
errors::InvalidArgument("g and a must be the same size"));
functor::ReluGrad<Device, T> functor;
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
output->flat<T>());
}
};
template <typename Device, typename T>
class Relu6Op : public UnaryElementWiseOp<T, Relu6Op<Device, T>> {
public:
......@@ -55,13 +76,13 @@ class Relu6Op : public UnaryElementWiseOp<T, Relu6Op<Device, T>> {
};
template <typename Device, typename T>
class ReluGradOp : public BinaryElementWiseOp<T, ReluGradOp<Device, T>> {
class Relu6GradOp : public BinaryElementWiseOp<T, Relu6GradOp<Device, T>> {
public:
using BinaryElementWiseOp<T, ReluGradOp<Device, T>>::BinaryElementWiseOp;
using BinaryElementWiseOp<T, Relu6GradOp<Device, T>>::BinaryElementWiseOp;
// INPUTS:
// g (gradients): backpropagated gradients
// a (inputs): inputs that were passed to ReluOp()
// a (inputs): inputs that were passed to Relu6Op()
// OUTPUT:
// gradients to backprop
template <int NDIMS>
......@@ -69,20 +90,32 @@ class ReluGradOp : public BinaryElementWiseOp<T, ReluGradOp<Device, T>> {
Tensor* output) {
OP_REQUIRES(context, a.IsSameSize(g),
errors::InvalidArgument("g and a must be the same size"));
functor::ReluGrad<Device, T> functor;
functor::Relu6Grad<Device, T> functor;
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
output->flat<T>());
}
};
template <typename Device, typename T>
class Relu6GradOp : public BinaryElementWiseOp<T, Relu6GradOp<Device, T>> {
class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> {
public:
using BinaryElementWiseOp<T, Relu6GradOp<Device, T>>::BinaryElementWiseOp;
using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp;
void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
functor::Elu<Device, T> functor;
functor(context->eigen_device<Device>(), input.flat<T>(),
output->flat<T>());
}
};
template <typename Device, typename T>
class EluGradOp : public BinaryElementWiseOp<T, EluGradOp<Device, T>> {
public:
using BinaryElementWiseOp<T, EluGradOp<Device, T>>::BinaryElementWiseOp;
// INPUTS:
// g (gradients): backpropagated gradients
// a (inputs): inputs that were passed to Relu6Op()
// a (outputs): outputs of the EluOp()
// OUTPUT:
// gradients to backprop
template <int NDIMS>
......@@ -90,58 +123,83 @@ class Relu6GradOp : public BinaryElementWiseOp<T, Relu6GradOp<Device, T>> {
Tensor* output) {
OP_REQUIRES(context, a.IsSameSize(g),
errors::InvalidArgument("g and a must be the same size"));
functor::Relu6Grad<Device, T> functor;
functor::EluGrad<Device, T> functor;
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
output->flat<T>());
}
};
#define REGISTER_KERNELS(type) \
#define REGISTER_RELU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Relu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
ReluOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("Relu6").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
Relu6Op<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
ReluGradOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("Relu6").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
Relu6Op<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
Relu6GradOp<CPUDevice, type>)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS);
#undef REGISTER_RELU_KERNELS
#define REGISTER_ELU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Elu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
EluOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("EluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
EluGradOp<CPUDevice, type>)
// Elu only makes sense with float or double.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_ELU_KERNELS);
#undef REGISTER_ELU_KERNELS
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void Relu<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor activations); \
extern template struct Relu<GPUDevice, T>; \
\
template <> \
void ReluGrad<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor backprops); \
\
extern template struct ReluGrad<GPUDevice, T>; \
template <> \
void Relu6<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor activations); \
extern template struct Relu6<GPUDevice, T>; \
\
template <> \
void Relu6Grad<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor backprops); \
extern template struct Relu6Grad<GPUDevice, T>;
#define DECLARE_GPU_SPEC(T) \
template <> \
void Relu<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor activations); \
extern template struct Relu<GPUDevice, T>; \
\
template <> \
void ReluGrad<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor backprops); \
extern template struct ReluGrad<GPUDevice, T>; \
\
template <> \
void Relu6<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor activations); \
extern template struct Relu6<GPUDevice, T>; \
\
template <> \
void Relu6Grad<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor backprops); \
extern template struct Relu6Grad<GPUDevice, T>; \
\
template <> \
void Elu<GPUDevice, T>::operator()(const GPUDevice& d, \
typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor activations); \
extern template struct Elu<GPUDevice, T>; \
\
template <> \
void EluGrad<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
typename TTypes<T>::ConstTensor activations, \
typename TTypes<T>::Tensor backprops); \
extern template struct EluGrad<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
} // namespace functor
......@@ -151,15 +209,21 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
REGISTER_KERNEL_BUILDER( \
Name("Relu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
ReluOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("Relu6").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
Relu6Op<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
ReluGradOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("Relu6").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
Relu6Op<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
Relu6GradOp<GPUDevice, type>)
Relu6GradOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("Elu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
EluOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
EluGradOp<GPUDevice, type>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
......
......@@ -88,6 +88,40 @@ struct Relu6Grad {
}
};
// Functor used by EluOp to do the computations.
template <typename Device, typename T>
struct Elu {
// Computes Relu activation.
//
// features: any shape.
// activations: same shape as "features".
void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
typename TTypes<T>::Tensor activations) {
// features.constant(?)
activations.device(d) =
(features < static_cast<T>(0))
.select(features.exp() - features.constant(static_cast<T>(1)),
features);
}
};
// Functor used by EluGradOp to do the computations.
template <typename Device, typename T>
struct EluGrad {
// Computes EluGrad backprops.
//
// gradients: gradients backpropagated to the Elu op.
// activations: outputs of the Elu op.
// backprops: gradients to backpropagate to the Elu inputs.
void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
typename TTypes<T>::ConstTensor activations,
typename TTypes<T>::Tensor backprops) {
backprops.device(d) =
(activations < static_cast<T>(0))
.select((activations + static_cast<T>(1)) * gradients, gradients);
}
};
} // namespace functor
} // namespace tensorflow
......
......@@ -29,11 +29,13 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
// Definition of the GPU implementations declared in relu_op.cc.
#define DEFINE_GPU_KERNELS(T) \
template struct functor::Relu<GPUDevice, T>; \
template struct functor::ReluGrad<GPUDevice, T>; \
template struct functor::Relu6<GPUDevice, T>; \
template struct functor::Relu6Grad<GPUDevice, T>;
#define DEFINE_GPU_KERNELS(T) \
template struct functor::Relu<GPUDevice, T>; \
template struct functor::ReluGrad<GPUDevice, T>; \
template struct functor::Relu6<GPUDevice, T>; \
template struct functor::Relu6Grad<GPUDevice, T>; \
template struct functor::Elu<GPUDevice, T>; \
template struct functor::EluGrad<GPUDevice, T>
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
......
......@@ -445,6 +445,31 @@ backprops: The gradients:
`gradients * features * (features > 0) * (features < 6)`.
)doc");
REGISTER_OP("Elu")
.Input("features: T")
.Output("activations: T")
.Attr("T: {float, double}")
.Doc(R"doc(
Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise.
See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)
](http://arxiv.org/abs/1511.07289)
)doc");
REGISTER_OP("EluGrad")
.Input("gradients: T")
.Input("outputs: T")
.Output("backprops: T")
.Attr("T: {float, double}")
.Doc(R"doc(
Computes gradients for the exponential linear (Elu) operation.
gradients: The backpropagated gradients to the corresponding Elu operation.
outputs: The outputs of the corresponding Elu operation.
backprops: The gradients: `gradients * (outputs + 1)` if outputs < 0,
`gradients` otherwise.
)doc");
REGISTER_OP("Softplus")
.Input("features: T")
.Output("activations: T")
......
......@@ -2096,6 +2096,58 @@ op {
summary: "Computes the (possibly normalized) Levenshtein Edit Distance."
description: "The inputs are variable-length sequences provided by SparseTensors\n (hypothesis_indices, hypothesis_values, hypothesis_shape)\nand\n (truth_indices, truth_values, truth_shape).\n\nThe inputs are:"
}
op {
name: "Elu"
input_arg {
name: "features"
type_attr: "T"
}
output_arg {
name: "activations"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
summary: "Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise."
description: "See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)\n](http://arxiv.org/abs/1511.07289)"
}
op {
name: "EluGrad"
input_arg {
name: "gradients"
description: "The backpropagated gradients to the corresponding Elu operation."
type_attr: "T"
}
input_arg {
name: "outputs"
description: "The outputs of the corresponding Elu operation."
type_attr: "T"
}
output_arg {
name: "backprops"
description: "The gradients: `gradients * (outputs + 1)` if outputs < 0,\n`gradients` otherwise."
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
summary: "Computes gradients for the exponential linear (Elu) operation."
}
op {
name: "EncodeJpeg"
input_arg {
......
......@@ -38,10 +38,14 @@ IMAGEUTILS_METHOD(convertYUV420SPToARGB8888)(
JNIEnv* env, jclass clazz, jbyteArray input, jintArray output,
jint width, jint height, jboolean halfSize);
JNIEXPORT void JNICALL
IMAGEUTILS_METHOD(convertYUV420SPToRGB565)(
JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output,
jint width, jint height);
JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420ToARGB8888)(
JNIEnv* env, jclass clazz, jbyteArray y, jbyteArray u, jbyteArray v,
jintArray output, jint width, jint height, jint y_row_stride,
jint uv_row_stride, jint uv_pixel_stride, jboolean halfSize);
JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420SPToRGB565)(
JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output, jint width,
jint height);
JNIEXPORT void JNICALL
IMAGEUTILS_METHOD(convertARGB8888ToYUV420SP)(
......@@ -82,10 +86,39 @@ IMAGEUTILS_METHOD(convertYUV420SPToARGB8888)(
env->ReleaseIntArrayElements(output, o, 0);
}
JNIEXPORT void JNICALL
IMAGEUTILS_METHOD(convertYUV420SPToRGB565)(
JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output,
jint width, jint height) {
JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420ToARGB8888)(
JNIEnv* env, jclass clazz, jbyteArray y, jbyteArray u, jbyteArray v,
jintArray output, jint width, jint height, jint y_row_stride,
jint uv_row_stride, jint uv_pixel_stride, jboolean halfSize) {
jboolean inputCopy = JNI_FALSE;
jbyte* const y_buff = env->GetByteArrayElements(y, &inputCopy);
jboolean outputCopy = JNI_FALSE;
jint* const o = env->GetIntArrayElements(output, &outputCopy);
if (halfSize) {
ConvertYUV420SPToARGB8888HalfSize(reinterpret_cast<uint8*>(y_buff),
reinterpret_cast<uint32*>(o), width,
height);
} else {
jbyte* const u_buff = env->GetByteArrayElements(u, &inputCopy);
jbyte* const v_buff = env->GetByteArrayElements(v, &inputCopy);
ConvertYUV420ToARGB8888(
reinterpret_cast<uint8*>(y_buff), reinterpret_cast<uint8*>(u_buff),
reinterpret_cast<uint8*>(v_buff), reinterpret_cast<uint32*>(o), width,
height, y_row_stride, uv_row_stride, uv_pixel_stride);
env->ReleaseByteArrayElements(u, u_buff, JNI_ABORT);
env->ReleaseByteArrayElements(v, v_buff, JNI_ABORT);
}
env->ReleaseByteArrayElements(y, y_buff, JNI_ABORT);
env->ReleaseIntArrayElements(output, o, 0);
}
JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420SPToRGB565)(
JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output, jint width,
jint height) {
jboolean inputCopy = JNI_FALSE;
jbyte* const i = env->GetByteArrayElements(input, &inputCopy);
......
......@@ -27,6 +27,58 @@ limitations under the License.
// are normalized to eight bits.
static const int kMaxChannelValue = 262143;
static inline uint32 YUV2RGB(int nY, int nU, int nV) {
nY -= 16;
nU -= 128;
nV -= 128;
if (nY < 0) nY = 0;
// This is the floating point equivalent. We do the conversion in integer
// because some Android devices do not have floating point in hardware.
// nR = (int)(1.164 * nY + 2.018 * nU);
// nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU);
// nB = (int)(1.164 * nY + 1.596 * nV);
int nR = (int)(1192 * nY + 1634 * nV);
int nG = (int)(1192 * nY - 833 * nV - 400 * nU);
int nB = (int)(1192 * nY + 2066 * nU);
nR = MIN(kMaxChannelValue, MAX(0, nR));
nG = MIN(kMaxChannelValue, MAX(0, nG));
nB = MIN(kMaxChannelValue, MAX(0, nB));
nR = (nR >> 10) & 0xff;
nG = (nG >> 10) & 0xff;
nB = (nB >> 10) & 0xff;
return 0xff000000 | (nR << 16) | (nG << 8) | nB;
}
// Accepts a YUV 4:2:0 image with a plane of 8 bit Y samples followed by
// separate u and v planes with arbitrary row and column strides,
// containing 8 bit 2x2 subsampled chroma samples.
// Converts to a packed ARGB 32 bit output of the same pixel dimensions.
void ConvertYUV420ToARGB8888(const uint8* const yData, const uint8* const uData,
const uint8* const vData, uint32* const output,
const int width, const int height,
const int y_row_stride, const int uv_row_stride,
const int uv_pixel_stride) {
uint32* out = output;
for (int y = 0; y < height; y++) {
const uint8* pY = yData + y_row_stride * y;
const int uv_row_start = uv_row_stride * (y >> 1);
const uint8* pU = uData + uv_row_start;
const uint8* pV = vData + uv_row_start;
for (int x = 0; x < width; x++) {
const int uv_offset = (x >> 1) * uv_pixel_stride;
*out++ = YUV2RGB(pY[x], pU[uv_offset], pV[uv_offset]);
}
}
}
// Accepts a YUV 4:2:0 image with a plane of 8 bit Y samples followed by an
// interleaved U/V plane containing 8 bit 2x2 subsampled chroma samples,
// except the interleave order of U and V is reversed. Converts to a packed
......@@ -51,29 +103,7 @@ void ConvertYUV420SPToARGB8888(const uint8* const yData,
int nU = pUV[offset + 1];
#endif
nY -= 16;
nU -= 128;
nV -= 128;
if (nY < 0) nY = 0;
// This is the floating point equivalent. We do the conversion in integer
// because some Android devices do not have floating point in hardware.
// nR = (int)(1.164 * nY + 2.018 * nU);
// nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU);
// nB = (int)(1.164 * nY + 1.596 * nV);
int nR = (int)(1192 * nY + 1634 * nV);
int nG = (int)(1192 * nY - 833 * nV - 400 * nU);
int nB = (int)(1192 * nY + 2066 * nU);
nR = MIN(kMaxChannelValue, MAX(0, nR));
nG = MIN(kMaxChannelValue, MAX(0, nG));
nB = MIN(kMaxChannelValue, MAX(0, nB));
nR = (nR >> 10) & 0xff;
nG = (nG >> 10) & 0xff;
nB = (nB >> 10) & 0xff;
*out++ = 0xff000000 | (nR << 16) | (nG << 8) | nB;
*out++ = YUV2RGB(nY, nU, nV);
}
}
}
......@@ -101,23 +131,7 @@ void ConvertYUV420SPToARGB8888HalfSize(const uint8* const input,
int nU = *pUV++;
#endif
nY -= 16;
nU -= 128;
nV -= 128;
if (nY < 0) nY = 0;
int nR = (int)(1192 * nY + 1634 * nV);
int nG = (int)(1192 * nY - 833 * nV - 400 * nU);
int nB = (int)(1192 * nY + 2066 * nU);
nR = MIN(kMaxChannelValue, MAX(0, nR));
nG = MIN(kMaxChannelValue, MAX(0, nG));
nB = MIN(kMaxChannelValue, MAX(0, nB));
nR = (nR >> 10) & 0xff;
nG = (nG >> 10) & 0xff;
nB = (nB >> 10) & 0xff;
*out++ = 0xff000000 | (nR << 16) | (nG << 8) | nB;
*out++ = YUV2RGB(nY, nU, nV);
}
pY += stride;
}
......
......@@ -27,6 +27,12 @@ using namespace tensorflow;
extern "C" {
#endif
void ConvertYUV420ToARGB8888(const uint8* const yData, const uint8* const uData,
const uint8* const vData, uint32* const output,
const int width, const int height,
const int y_row_stride, const int uv_row_stride,
const int uv_pixel_stride);
// Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width
// and height. The input and output must already be allocated and non-null.
// For efficiency, no error checking is performed.
......
......@@ -24,12 +24,12 @@ import android.media.Image;
import android.media.Image.Plane;
import android.media.ImageReader;
import android.media.ImageReader.OnImageAvailableListener;
import junit.framework.Assert;
import org.tensorflow.demo.env.ImageUtils;
import org.tensorflow.demo.env.Logger;
import java.nio.ByteBuffer;
import java.util.List;
/**
......@@ -38,7 +38,7 @@ import java.util.List;
public class TensorflowImageListener implements OnImageAvailableListener {
private static final Logger LOGGER = new Logger();
private static final boolean SAVE_PREVIEW_BITMAP = false;
private static final boolean SAVE_PREVIEW_BITMAP = true;
private static final String MODEL_FILE = "file:///android_asset/tensorflow_inception_graph.pb";
private static final String LABEL_FILE =
......@@ -55,7 +55,7 @@ public class TensorflowImageListener implements OnImageAvailableListener {
private int previewWidth = 0;
private int previewHeight = 0;
private byte[] yuvBytes = null;
private byte[][] yuvBytes;
private int[] rgbBytes = null;
private Bitmap rgbFrameBitmap = null;
private Bitmap croppedBitmap = null;
......@@ -68,44 +68,6 @@ public class TensorflowImageListener implements OnImageAvailableListener {
this.scoreView = scoreView;
}
private void readPlanesToYuvBuffer(final Plane[] planes, final byte[] yuvBytes) {
int position = 0;
// Copy the bytes from the Image into a buffer for easier conversion to RGB.
// TODO(andrewharp): Modify native code to accept multiple buffers so that
// only one pass is necessary during conversion to RGB.
final Plane yPlane = planes[0];
final ByteBuffer yBuffer = yPlane.getBuffer();
final int yRowStride = yPlane.getRowStride();
// Read the y (luminance buffer).
for (int row = 0; row < previewHeight; ++row) {
yBuffer.position(yRowStride * row);
// Pixel stride is guaranteed to be 1 so we can
// just do a copy operation.
yBuffer.get(yuvBytes, position, previewWidth);
position += previewWidth;
}
// Interleave the u and v buffers.
final ByteBuffer uBuffer = planes[1].getBuffer();
final ByteBuffer vBuffer = planes[2].getBuffer();
final int uvPixelStride = planes[1].getPixelStride();
final int uvWidth = previewWidth / 2;
final int uvHeight = previewHeight / 2;
Assert.assertEquals(
planes[1].getRowStride(), planes[2].getRowStride());
for (int y = 0; y < uvHeight; ++y) {
int readPos = planes[1].getRowStride() * y;
for (int x = 0; x < uvWidth; ++x) {
yuvBytes[position++] = vBuffer.get(readPos);
yuvBytes[position++] = uBuffer.get(readPos);
readPos += uvPixelStride;
}
}
}
private void drawResizedBitmap(final Bitmap src, final Bitmap dst) {
Assert.assertEquals(dst.getWidth(), dst.getHeight());
final float minDim = Math.min(src.getWidth(), src.getHeight());
......@@ -141,6 +103,8 @@ public class TensorflowImageListener implements OnImageAvailableListener {
return;
}
final Plane[] planes = image.getPlanes();
// Initialize the storage bitmaps once when the resolution is known.
if (previewWidth != image.getWidth() || previewHeight != image.getHeight()) {
previewWidth = image.getWidth();
......@@ -148,16 +112,35 @@ public class TensorflowImageListener implements OnImageAvailableListener {
LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
rgbBytes = new int[previewWidth * previewHeight];
yuvBytes = new byte[ImageUtils.getYUVByteSize(previewWidth, previewHeight)];
rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888);
yuvBytes = new byte[planes.length][];
for (int i = 0; i < planes.length; ++i) {
yuvBytes[i] = new byte[planes[i].getBuffer().capacity()];
}
}
readPlanesToYuvBuffer(image.getPlanes(), yuvBytes);
for (int i = 0; i < planes.length; ++i) {
planes[i].getBuffer().get(yuvBytes[i]);
}
image.close();
final int yRowStride = planes[0].getRowStride();
final int uvRowStride = planes[1].getRowStride();
final int uvPixelStride = planes[1].getPixelStride();
ImageUtils.convertYUV420ToARGB8888(
yuvBytes[0],
yuvBytes[1],
yuvBytes[2],
rgbBytes,
previewWidth,
previewHeight,
yRowStride,
uvRowStride,
uvPixelStride,
false);
ImageUtils.convertYUV420SPToARGB8888(yuvBytes, rgbBytes, previewWidth, previewHeight, false);
image.close();
} catch (final Exception e) {
if (image != null) {
image.close();
......
......@@ -87,6 +87,32 @@ public class ImageUtils {
public static native void convertYUV420SPToARGB8888(
byte[] input, int[] output, int width, int height, boolean halfSize);
/**
* Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width
* and height. The input and output must already be allocated and non-null.
* For efficiency, no error checking is performed.
*
* @param y
* @param u
* @param v
* @param uvPixelStride
* @param width The width of the input image.
* @param height The height of the input image.
* @param halfSize If true, downsample to 50% in each dimension, otherwise not.
* @param output A pre-allocated array for the ARGB 8:8:8:8 output data.
*/
public static native void convertYUV420ToARGB8888(
byte[] y,
byte[] u,
byte[] v,
int[] output,
int width,
int height,
int yRowStride,
int uvRowStride,
int uvPixelStride,
boolean halfSize);
/**
* Converts YUV420 semi-planar data to RGB 565 data using the supplied width
* and height. The input and output must already be allocated and non-null.
......
......@@ -402,7 +402,7 @@ deviations from the mean are dropped and re-picked.
- - -
### `tf.random_uniform(shape, minval=0.0, maxval=1.0, dtype=tf.float32, seed=None, name=None)` {#random_uniform}
### `tf.random_uniform(shape, minval=0, maxval=None, dtype=tf.float32, seed=None, name=None)` {#random_uniform}
Outputs random values from a uniform distribution.
......@@ -410,15 +410,24 @@ The generated values follow a uniform distribution in the range
`[minval, maxval)`. The lower bound `minval` is included in the range, while
the upper bound `maxval` is excluded.
For floats, the default range is `[0, 1)`. For ints, at least `maxval` must
be specified explicitly.
In the integer case, the random integers are slightly biased unless
`maxval - minval` is an exact power of two. The bias is small for values of
`maxval - minval` significantly smaller than the range of the output (either
`2**32` or `2**64`).
##### Args:
* <b>`shape`</b>: A 1-D integer Tensor or Python array. The shape of the output tensor.
* <b>`minval`</b>: A 0-D Tensor or Python value of type `dtype`. The lower bound on the
range of random values to generate.
range of random values to generate. Defaults to 0.
* <b>`maxval`</b>: A 0-D Tensor or Python value of type `dtype`. The upper bound on
the range of random values to generate.
* <b>`dtype`</b>: The type of the output.
the range of random values to generate. Defaults to 1 if `dtype` is
floating point.
* <b>`dtype`</b>: The type of the output: `float32`, `float64`, `int32`, or `int64`.
* <b>`seed`</b>: A Python integer. Used to create a random seed for the distribution.
See
[`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
......@@ -429,6 +438,11 @@ the upper bound `maxval` is excluded.
A tensor of the specified shape filled with random uniform values.
##### Raises:
* <b>`ValueError`</b>: If `dtype` is integral and `maxval` is not specified.
- - -
......
......@@ -270,6 +270,7 @@
* [`conv2d`](../../api_docs/python/nn.md#conv2d)
* [`depthwise_conv2d`](../../api_docs/python/nn.md#depthwise_conv2d)
* [`dropout`](../../api_docs/python/nn.md#dropout)
* [`elu`](../../api_docs/python/nn.md#elu)
* [`embedding_lookup`](../../api_docs/python/nn.md#embedding_lookup)
* [`fixed_unigram_candidate_sampler`](../../api_docs/python/nn.md#fixed_unigram_candidate_sampler)
* [`in_top_k`](../../api_docs/python/nn.md#in_top_k)
......@@ -332,6 +333,8 @@
* [`Coordinator`](../../api_docs/python/train.md#Coordinator)
* [`exponential_decay`](../../api_docs/python/train.md#exponential_decay)
* [`ExponentialMovingAverage`](../../api_docs/python/train.md#ExponentialMovingAverage)
* [`FeatureList`](../../api_docs/python/train.md#FeatureList)
* [`FeatureLists`](../../api_docs/python/train.md#FeatureLists)
* [`FtrlOptimizer`](../../api_docs/python/train.md#FtrlOptimizer)
* [`global_norm`](../../api_docs/python/train.md#global_norm)
* [`global_step`](../../api_docs/python/train.md#global_step)
......
......@@ -10,9 +10,10 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by
## Activation Functions
The activation ops provide different types of nonlinearities for use in neural
networks. These include smooth nonlinearities (`sigmoid`, `tanh`, `softplus`,
and `softsign`), continuous but not everywhere differentiable functions (`relu`,
`relu6`, and `relu_x`), and random regularization (`dropout`).
networks. These include smooth nonlinearities (`sigmoid`, `tanh`, `elu`,
`softplus`, and `softsign`), continuous but not everywhere differentiable
functions (`relu`, `relu6`, and `relu_x`), and random regularization
(`dropout`).
All activation ops apply componentwise, and produce a tensor of the same
shape as the input tensor.
......@@ -52,6 +53,26 @@ Computes Rectified Linear 6: `min(max(features, 0), 6)`.
A `Tensor` with the same type as `features`.
- - -
### `tf.nn.elu(features, name=None)` {#elu}
Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise.
See [Fast and Aaccurate Deep Network Learning by Exponential Linear Units (ELUs)
](http://arxiv.org/abs/1511.07289)
##### Args:
* <b>`features`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
A `Tensor`. Has the same type as `features`.
- - -
### `tf.nn.softplus(features, name=None)` {#softplus}
......
......@@ -946,7 +946,7 @@ assert v1 == v
Sharing a variable by capturing a scope and setting reuse:
```python
with tf.variable_scope("foo") as scope.
with tf.variable_scope("foo") as scope:
v = tf.get_variable("v", [1])
scope.reuse_variables()
v1 = tf.get_variable("v", [1])
......@@ -957,7 +957,7 @@ To prevent accidental sharing of variables, we raise an exception when
getting an existing variable in a non-reusing scope.
```python
with tf.variable_scope("foo") as scope.
with tf.variable_scope("foo"):
v = tf.get_variable("v", [1])
v1 = tf.get_variable("v", [1])
# Raises ValueError("... v already exists ...").
......
......@@ -1788,3 +1788,351 @@ tf.train.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt')
* <b>`as_text`</b>: If `True`, writes the graph as an ASCII proto.
## Other Functions and Classes
- - -
### `class tf.train.FeatureList` {#FeatureList}
- - -
#### `tf.train.FeatureList.ByteSize()` {#FeatureList.ByteSize}
- - -
#### `tf.train.FeatureList.Clear()` {#FeatureList.Clear}
- - -
#### `tf.train.FeatureList.ClearExtension(extension_handle)` {#FeatureList.ClearExtension}
- - -
#### `tf.train.FeatureList.ClearField(field_name)` {#FeatureList.ClearField}
- - -
#### `tf.train.FeatureList.CopyFrom(other_msg)` {#FeatureList.CopyFrom}
Copies the content of the specified message into the current message.
The method clears the current message and then merges the specified
message using MergeFrom.
##### Args:
* <b>`other_msg`</b>: Message to copy into the current one.
- - -
#### `tf.train.FeatureList.FindInitializationErrors()` {#FeatureList.FindInitializationErrors}
Finds required fields which are not initialized.
##### Returns:
A list of strings. Each string is a path to an uninitialized field from
the top-level message, e.g. "foo.bar[5].baz".
- - -
#### `tf.train.FeatureList.FromString(s)` {#FeatureList.FromString}
- - -
#### `tf.train.FeatureList.HasExtension(extension_handle)` {#FeatureList.HasExtension}
- - -
#### `tf.train.FeatureList.HasField(field_name)` {#FeatureList.HasField}
- - -
#### `tf.train.FeatureList.IsInitialized(errors=None)` {#FeatureList.IsInitialized}
Checks if all required fields of a message are set.
##### Args:
* <b>`errors`</b>: A list which, if provided, will be populated with the field
paths of all missing required fields.
##### Returns:
True iff the specified message has all required fields set.
- - -
#### `tf.train.FeatureList.ListFields()` {#FeatureList.ListFields}
- - -
#### `tf.train.FeatureList.MergeFrom(msg)` {#FeatureList.MergeFrom}
- - -
#### `tf.train.FeatureList.MergeFromString(serialized)` {#FeatureList.MergeFromString}
- - -
#### `tf.train.FeatureList.ParseFromString(serialized)` {#FeatureList.ParseFromString}
Parse serialized protocol buffer data into this message.
Like MergeFromString(), except we clear the object first and
do not return the value that MergeFromString returns.
- - -
#### `tf.train.FeatureList.RegisterExtension(extension_handle)` {#FeatureList.RegisterExtension}
- - -
#### `tf.train.FeatureList.SerializePartialToString()` {#FeatureList.SerializePartialToString}
- - -
#### `tf.train.FeatureList.SerializeToString()` {#FeatureList.SerializeToString}
- - -
#### `tf.train.FeatureList.SetInParent()` {#FeatureList.SetInParent}
Sets the _cached_byte_size_dirty bit to true,
and propagates this to our listener iff this was a state change.
- - -
#### `tf.train.FeatureList.WhichOneof(oneof_name)` {#FeatureList.WhichOneof}
Returns the name of the currently set field inside a oneof, or None.
- - -
#### `tf.train.FeatureList.feature` {#FeatureList.feature}
Magic attribute generated for "feature" proto field.
- - -
### `class tf.train.FeatureLists` {#FeatureLists}
- - -
#### `tf.train.FeatureLists.ByteSize()` {#FeatureLists.ByteSize}
- - -
#### `tf.train.FeatureLists.Clear()` {#FeatureLists.Clear}
- - -
#### `tf.train.FeatureLists.ClearExtension(extension_handle)` {#FeatureLists.ClearExtension}
- - -
#### `tf.train.FeatureLists.ClearField(field_name)` {#FeatureLists.ClearField}
- - -
#### `tf.train.FeatureLists.CopyFrom(other_msg)` {#FeatureLists.CopyFrom}
Copies the content of the specified message into the current message.
The method clears the current message and then merges the specified
message using MergeFrom.
##### Args:
* <b>`other_msg`</b>: Message to copy into the current one.
- - -
#### `tf.train.FeatureLists.FindInitializationErrors()` {#FeatureLists.FindInitializationErrors}
Finds required fields which are not initialized.
##### Returns:
A list of strings. Each string is a path to an uninitialized field from
the top-level message, e.g. "foo.bar[5].baz".
- - -
#### `tf.train.FeatureLists.FromString(s)` {#FeatureLists.FromString}
- - -
#### `tf.train.FeatureLists.HasExtension(extension_handle)` {#FeatureLists.HasExtension}
- - -
#### `tf.train.FeatureLists.HasField(field_name)` {#FeatureLists.HasField}
- - -
#### `tf.train.FeatureLists.IsInitialized(errors=None)` {#FeatureLists.IsInitialized}
Checks if all required fields of a message are set.
##### Args:
* <b>`errors`</b>: A list which, if provided, will be populated with the field
paths of all missing required fields.
##### Returns:
True iff the specified message has all required fields set.
- - -
#### `tf.train.FeatureLists.ListFields()` {#FeatureLists.ListFields}
- - -
#### `tf.train.FeatureLists.MergeFrom(msg)` {#FeatureLists.MergeFrom}
- - -
#### `tf.train.FeatureLists.MergeFromString(serialized)` {#FeatureLists.MergeFromString}
- - -
#### `tf.train.FeatureLists.ParseFromString(serialized)` {#FeatureLists.ParseFromString}
Parse serialized protocol buffer data into this message.
Like MergeFromString(), except we clear the object first and
do not return the value that MergeFromString returns.
- - -
#### `tf.train.FeatureLists.RegisterExtension(extension_handle)` {#FeatureLists.RegisterExtension}
- - -
#### `tf.train.FeatureLists.SerializePartialToString()` {#FeatureLists.SerializePartialToString}
- - -
#### `tf.train.FeatureLists.SerializeToString()` {#FeatureLists.SerializeToString}
- - -
#### `tf.train.FeatureLists.SetInParent()` {#FeatureLists.SetInParent}
Sets the _cached_byte_size_dirty bit to true,
and propagates this to our listener iff this was a state change.
- - -
#### `tf.train.FeatureLists.WhichOneof(oneof_name)` {#FeatureLists.WhichOneof}
Returns the name of the currently set field inside a oneof, or None.
- - -
#### `tf.train.FeatureLists.feature_list` {#FeatureLists.feature_list}
Magic attribute generated for "feature_list" proto field.
......@@ -39,7 +39,7 @@ Python.
The packages that will be installed or upgraded during the pip install are listed in the
[REQUIRED_PACKAGES section of setup.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/pip_package/setup.py)
Install pip if not already installed:
Install pip if it is not already installed:
```bash
# Ubuntu/Linux 64-bit
......@@ -118,8 +118,9 @@ $ source ~/tensorflow/bin/activate.csh # If using csh
With the Virtualenv environment activated, you can now
[test your installation](#test_install).
When you are done using TensorFlow, deactivate the environment.
```bash
# When you are done using TensorFlow, deactivate the environment.
(tensorflow)$ deactivate
$ # Your prompt should change back
......@@ -152,16 +153,15 @@ We provide 2 Docker images:
With Docker the installation is as follows:
* Install Docker on your machine.
* Create a [Docker
group](http://docs.docker.com/engine/installation/ubuntulinux/#create-a-docker-group)
to allow launching containers without `sudo`.
* Launch a Docker container with the TensorFlow image. The image
gets downloaded automatically on first launch.
See [installing Docker](http://docs.docker.com/engine/installation/) for instructions
on installing Docker on your machine.
Also create a [Docker
group](http://docs.docker.com/engine/installation/ubuntulinux/#create-a-docker-group)
to allow launching containers without `sudo`.
After Docker is installed, launch a Docker container with the TensorFlow binary
image as follows.
......@@ -169,7 +169,7 @@ image as follows.
$ docker run -it b.gcr.io/tensorflow/tensorflow
```
Within the Docker container, you can now [test your installation](#test_install).
You can now [test your installation](#test_install) within the Docker container.
You can alternatively launch the TensorFlow source image, for example if you want
to experiment directly with the source.
......@@ -196,7 +196,7 @@ export CUDA_HOME=/usr/local/cuda
### Run TensorFlow from the Command Line
See [common problems](#common_install_problems) if some error happens.
See [common problems](#common_install_problems) if an error happens.
Open a terminal and type the following:
......@@ -275,10 +275,10 @@ $ chmod +x PATH_TO_INSTALL.SH
$ ./PATH_TO_INSTALL.SH --user
```
Remember to replace `PATH_TO_INSTALL.SH` to point to the location where you
Remember to replace `PATH_TO_INSTALL.SH` with the location where you
downloaded the installer.
Finally, follow the instructions in that script to place bazel into your binary
Finally, follow the instructions in that script to place `bazel` into your binary
path.
#### Install other dependencies
......@@ -287,12 +287,26 @@ path.
$ sudo apt-get install python-numpy swig python-dev
```
#### Configure the installation {#configure}
Run the `configure` script at the root of the tree. The configure script
asks you for the path to your python interpreter and allows (optional)
configuration of the CUDA libraries (see [below](#configure_cuda)).
This step is used to locate the python and numpy header files.
```bash
$ ./configure
Please specify the location of python. [Default is /usr/bin/python]:
```
#### Optional: Install CUDA (GPUs on Linux) {#install_cuda}
In order to build or run TensorFlow with GPU support, both Cuda Toolkit 7.0 and
CUDNN 6.5 V2 from NVIDIA need to be installed.
TensorFlow GPU support requires having a GPU card with NVidia Compute Capability >= 3.5. Supported cards include but are not limited to:
TensorFlow GPU support requires having a GPU card with NVidia Compute Capability >= 3.5.
Supported cards include but are not limited to:
* NVidia Titan
* NVidia Titan X
......@@ -318,12 +332,14 @@ sudo cp cudnn-6.5-linux-x64-v2/cudnn.h /usr/local/cuda/include
sudo cp cudnn-6.5-linux-x64-v2/libcudnn* /usr/local/cuda/lib64
```
##### Configure TensorFlow's canonical view of Cuda libraries
From the root of your source tree, run:
##### Configure TensorFlow's canonical view of Cuda libraries {#configure_cuda}
When running the `configure` script from the root of your source tree, select
the option `Y` when asked to build TensorFlow with GPU support.
``` bash
$ ./configure
Do you wish to build TensorFlow with GPU support? [y/n] y
Please specify the location of python. [Default is /usr/bin/python]:
Do you wish to build TensorFlow with GPU support? [y/N] y
GPU support will be enabled for TensorFlow
Please specify the location where CUDA 7.0 toolkit is installed. Refer to
......@@ -400,9 +416,9 @@ given necessary bazel new feature support.
### Installation for Mac OS X
Mac needs the same set of dependencies as Linux, however installing those
dependencies is different. Here is a set of useful links to help with installing
the dependencies on Mac OS X :
Mac needs the same set of dependencies as Linux, but the installation
process for those dependencies is different. Here is a set of useful links
to help with installing the dependencies on Mac OS X :
#### Bazel
......@@ -420,6 +436,18 @@ Notes : You need to install
Follow installation instructions [here](http://docs.scipy.org/doc/numpy/user/install.html).
#### Configure the installation {#configure_osx}
Run the `configure` script at the root of the tree. The configure script
asks you for the path to your python interpreter.
This step is used to locate the python and numpy header files.
```bash
$ ./configure
Please specify the location of python. [Default is /usr/bin/python]:
Do you wish to build TensorFlow with GPU support? [y/N]
```
### Create the pip package and install {#create-pip}
......@@ -505,7 +533,7 @@ SSLError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed
Solution: Download the wheel manually via curl or wget, and pip install locally.
### On Linux
### Linux issues
If you encounter:
......@@ -533,7 +561,7 @@ Solution: TensorFlow depends on protobuf, which requires the Python package
You can resolve the issue in one of the following ways:
* Upgrade the Python installation with the current version `six`:
* Upgrade the Python installation with the current version of `six`:
```bash
$ sudo easy_install -U six
......
......@@ -14,16 +14,6 @@ load("/tensorflow/tensorflow", "cuda_py_tests")
load("/tensorflow/tensorflow", "tf_py_wrap_cc")
load("/tensorflow/core/platform/default/build_config", "tf_proto_library_py")
config_setting(
name = "macosx",
values = {"cpu": "darwin"},
)
numpy_macosx_include_dir = select({
":macosx": ["-I/System/Library/Frameworks/Python.framework/Versions/2.7/Extras/lib/python/numpy/core/include"],
"//conditions:default": [],
})
py_library(
name = "python",
srcs = ["__init__.py"],
......@@ -467,6 +457,7 @@ tf_gen_op_wrapper_py(
"MaxPoolGradWithArgmax",
"ReluGrad",
"Relu6Grad",
"EluGrad",
"SoftplusGrad",
"SoftsignGrad",
"BiasAdd",
......@@ -712,7 +703,6 @@ tf_cuda_library(
name = "tf_session_helper",
srcs = ["client/tf_session_helper.cc"],
hdrs = ["client/tf_session_helper.h"],
copts = numpy_macosx_include_dir + ["-I/usr/include/python2.7"],
deps = [
":construction_fails_op",
":test_kernel_label_op_kernel",
......@@ -721,13 +711,14 @@ tf_cuda_library(
"//tensorflow/core:kernels",
"//tensorflow/core:lib",
"//tensorflow/core:protos_cc",
"//third_party/py/numpy:headers",
"//util/python:python_headers",
],
)
tf_py_wrap_cc(
name = "client/pywraptensorflow_server_lib",
srcs = ["client/tensorflow_server.i"],
copts = numpy_macosx_include_dir,
swig_includes = [
"lib/core/status.i",
"lib/core/strings.i",
......@@ -737,13 +728,13 @@ tf_py_wrap_cc(
"//tensorflow/core",
"//tensorflow/core:lib",
"//tensorflow/core:protos_cc",
"//util/python:python_headers",
],
)
tf_py_wrap_cc(
name = "pywrap_tensorflow",
srcs = ["tensorflow.i"],
copts = numpy_macosx_include_dir,
swig_includes = [
"client/events_writer.i",
"client/tf_session.i",
......@@ -760,6 +751,7 @@ tf_py_wrap_cc(
":py_record_reader_lib",
":py_record_writer_lib",
":tf_session_helper",
"//util/python:python_headers",
],
)
......
......@@ -276,6 +276,10 @@ class ConcatOpTest(tf.test.TestCase):
concat = tf.concat(dim, [p1, c1, p2, c2])
self.assertEqual(4, concat.get_shape().ndims)
# All dimensions unknown.
concat2 = tf.concat(dim, [p1, p2])
self.assertEqual(None, concat2.get_shape())
# Rank doesn't match.
c3 = tf.constant(30.0, shape=[4, 4, 4])
with self.assertRaises(ValueError):
......
......@@ -33,6 +33,10 @@ features = lambda d: tf.train.Features(feature=d)
bytes_feature = lambda v: feature(bytes_list=tf.train.BytesList(value=v))
int64_feature = lambda v: feature(int64_list=tf.train.Int64List(value=v))
float_feature = lambda v: feature(float_list=tf.train.FloatList(value=v))
# Helpers for creating SequenceExample objects
feature_list = lambda l: tf.train.FeatureList(feature=l)
feature_lists = lambda d: tf.train.FeatureLists(feature_list=d)
sequence_example = tf.train.SequenceExample
def flatten(list_of_lists):
......@@ -475,5 +479,24 @@ class ParseSingleExampleTest(tf.test.TestCase):
}, expected_output)
class ParseSequenceExampleTest(tf.test.TestCase):
def testCreateSequenceExample(self):
value = sequence_example(
context=features({
"global_feature": float_feature([1, 2, 3]),
}),
feature_lists=feature_lists({
"repeated_feature_2_frames": feature_list([
bytes_feature(["a", "b", "c"]),
bytes_feature(["a", "d", "e"])]),
"repeated_feature_3_frames": feature_list([
int64_feature([3, 4, 5, 6, 7]),
int64_feature([-1, 0, 0, 0, 0]),
int64_feature([1, 2, 3, 4, 5])])
}))
value.SerializeToString() # Smoke test
if __name__ == "__main__":
tf.test.main()
......@@ -45,18 +45,18 @@ class ReluTest(tf.test.TestCase):
self.assertShapeEqual(np_relu, relu)
def testNumbers(self):
for t in [np.int32, np.int64, np.float, np.double]:
for t in [np.int32, np.int64, np.float32, np.float64]:
self._testRelu(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
use_gpu=False)
if t in [np.float, np.double]:
if t in [np.float32, np.float64]:
self._testRelu(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
use_gpu=True)
# The gradient test for ReLU is a bit tricky as the derivative is not well
# defined at around zero and we want to avoid that in terms of input values.
def testGradientFloat(self):
def testGradientFloat32(self):
with self.test_session():
x = tf.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
......@@ -70,7 +70,7 @@ class ReluTest(tf.test.TestCase):
y,
[2, 5],
x_init_value=x_init)
print("relu (float) gradient err = ", err)
print("relu (float32) gradient err = ", err)
self.assertLess(err, 1e-4)
def testGradientNaN(self):
......@@ -91,7 +91,7 @@ class ReluTest(tf.test.TestCase):
except Exception as e: # pylint: disable=broad-except
assert "ReluGrad input is not finite." in str(e)
def testGradientDouble(self):
def testGradientFloat64(self):
with self.test_session():
x = tf.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
......@@ -105,10 +105,10 @@ class ReluTest(tf.test.TestCase):
y,
[2, 5],
x_init_value=x_init)
print("relu (double) gradient err = ", err)
print("relu (float64) gradient err = ", err)
self.assertLess(err, 1e-10)
def testGradGradFloat(self):
def testGradGradFloat32(self):
with self.test_session():
x = tf.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
......@@ -123,10 +123,10 @@ class ReluTest(tf.test.TestCase):
z[0],
[2, 5],
x_init_value=x_init)
print("relu (float) gradient of gradient err = ", err)
print("relu (float32) gradient of gradient err = ", err)
self.assertLess(err, 1e-4)
def testGradGradDouble(self):
def testGradGradFloat64(self):
with self.test_session():
x = tf.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
......@@ -141,7 +141,7 @@ class ReluTest(tf.test.TestCase):
z[0],
[2, 5],
x_init_value=x_init)
print("relu (double) gradient of gradient err = ", err)
print("relu (float64) gradient of gradient err = ", err)
self.assertLess(err, 1e-10)
......@@ -169,7 +169,7 @@ class Relu6Test(tf.test.TestCase):
self.assertShapeEqual(np_relu6, relu6)
def testNumbers(self):
for t in [np.int32, np.int64, np.float, np.double]:
for t in [np.int32, np.int64, np.float32, np.float64]:
self._testRelu6(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
use_gpu=False)
......@@ -181,7 +181,7 @@ class Relu6Test(tf.test.TestCase):
# The gradient test for ReLU6 is a bit tricky as the derivative is
# not well defined at around zero and six and we want to avoid that
# in terms of input values.
def testGradientFloat(self):
def testGradientFloat32(self):
with self.test_session():
x = tf.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 6.1, 6.3, 6.5, 6.7, 6.9],
......@@ -195,10 +195,10 @@ class Relu6Test(tf.test.TestCase):
y,
[2, 5],
x_init_value=x_init)
print("relu6 (float) gradient err = ", err)
print("relu6 (float32) gradient err = ", err)
self.assertLess(err, 1e-4)
def testGradientDouble(self):
def testGradientFloat64(self):
with self.test_session():
x = tf.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 6.1, 6.3, 6.5, 6.7, 6.9],
......@@ -212,9 +212,67 @@ class Relu6Test(tf.test.TestCase):
y,
[2, 5],
x_init_value=x_init)
print("relu6 (double) gradient err = ", err)
print("relu6 (float64) gradient err = ", err)
self.assertLess(err, 1e-10)
class EluTest(tf.test.TestCase):
def _npElu(self, np_features):
return np.where(np_features < 0, np.exp(np_features) - 1, np_features)
def testNpElu(self):
self.assertAllClose(
np.array([[-0.59343034025, 0.7, -0.39346934028, 0.3, -0.09516258196],
[0.1, -0.25918177931, 0.5, -0.5034146962, 0.9]]),
self._npElu(np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -
0.7, 0.9]])))
def _testElu(self, np_features, use_gpu=False):
np_elu = self._npElu(np_features)
with self.test_session(use_gpu=use_gpu):
elu = tf.nn.elu(np_features)
tf_elu = elu.eval()
self.assertAllClose(np_elu, tf_elu)
self.assertShapeEqual(np_elu, elu)
def testNumbers(self):
for t in [np.float32, np.float64]:
self._testElu(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
use_gpu=False)
self._testElu(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
use_gpu=True)
def testGradientFloat32(self):
with self.test_session():
x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
x = tf.constant(x_val, name="x")
y = tf.nn.elu(x, name="elu")
x_init = np.asarray(x_val, dtype=np.float32, order="F")
err = tf.test.compute_gradient_error(x,
[2, 5],
y,
[2, 5],
x_init_value=x_init)
print("elu (float32) gradient err = ", err)
self.assertLess(err, 1e-4)
def testGradientFloat64(self):
with self.test_session():
x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
x = tf.constant(x_val, dtype=tf.float64, name="x")
y = tf.nn.elu(x, name="elu")
x_init = np.asarray(x_val, dtype=np.float64, order="F")
err = tf.test.compute_gradient_error(x,
[2, 5],
y,
[2, 5],
x_init_value=x_init)
print("elu (float64) gradient err = ", err)
self.assertLess(err, 1e-6)
if __name__ == "__main__":
tf.test.main()
......@@ -106,7 +106,7 @@ class VariableStoreTest(tf.test.TestCase):
with variable_scope.variable_scope(tower, reuse=True) as tower_shared:
self.assertEqual(tower_shared.name, "tower")
with tf.name_scope("scope") as sc:
self.assertEqual(sc, "foo_1/scope/")
self.assertEqual(sc, "foo_1/tower/scope/")
def testVarScopeNameScope(self):
with self.test_session():
......@@ -124,7 +124,65 @@ class VariableStoreTest(tf.test.TestCase):
self.assertEqual(sc2, "scope3/tower/scope2/")
with variable_scope.variable_scope(tower):
with tf.name_scope("scope2") as sc2:
self.assertEqual(sc2, "scope3/scope2/")
self.assertEqual(sc2, "scope3/tower_1/scope2/")
root_var_scope = variable_scope.get_variable_scope()
with tf.name_scope("scope4"):
with variable_scope.variable_scope(root_var_scope):
with tf.name_scope("scope2") as sc2:
self.assertEqual(sc2, "scope4/scope2/")
def testVarOpScope(self):
with self.test_session():
with tf.name_scope("scope1"):
with variable_scope.variable_op_scope([], "tower", "default"):
self.assertEqual(variable_scope.get_variable("w", []).name,
"tower/w:0")
with tf.name_scope("scope2") as sc2:
self.assertEqual(sc2, "scope1/tower/scope2/")
with variable_scope.variable_op_scope([], "tower", "default"):
with self.assertRaises(ValueError):
variable_scope.get_variable("w", [])
with tf.name_scope("scope2") as sc2:
self.assertEqual(sc2, "scope1/tower_1/scope2/")
with tf.name_scope("scope2"):
with variable_scope.variable_op_scope([], None, "default"):
self.assertEqual(variable_scope.get_variable("w", []).name,
"default/w:0")
with tf.name_scope("scope2") as sc2:
self.assertEqual(sc2, "scope2/default/scope2/")
with variable_scope.variable_op_scope([], None, "default"):
self.assertEqual(variable_scope.get_variable("w", []).name,
"default_1/w:0")
with tf.name_scope("scope2") as sc2:
self.assertEqual(sc2, "scope2/default_1/scope2/")
def testVarOpScopeReuse(self):
with self.test_session():
with tf.variable_scope("outer") as outer:
with variable_scope.variable_op_scope([], "tower", "default"):
self.assertEqual(variable_scope.get_variable("w", []).name,
"outer/tower/w:0")
with tf.name_scope("scope2") as sc2:
self.assertEqual(sc2, "outer/tower/scope2/")
with variable_scope.variable_op_scope([], None, "default"):
self.assertEqual(variable_scope.get_variable("w", []).name,
"outer/default/w:0")
with tf.name_scope("scope2") as sc2:
self.assertEqual(sc2, "outer/default/scope2/")
with tf.variable_scope(outer, reuse=True) as outer:
with variable_scope.variable_op_scope([], "tower", "default"):
self.assertEqual(variable_scope.get_variable("w", []).name,
"outer/tower/w:0")
with tf.name_scope("scope2") as sc2:
self.assertEqual(sc2, "outer_1/tower/scope2/")
with variable_scope.variable_op_scope([], None, "default"):
self.assertEqual(variable_scope.get_variable("w", []).name,
"outer/default/w:0")
with tf.name_scope("scope2") as sc2:
self.assertEqual(sc2, "outer_1/default/scope2/")
def testVarScopeGetVar(self):
with self.test_session():
......
......@@ -50,6 +50,29 @@ class XentTest(tf.test.TestCase):
self._testXent(features, labels, use_gpu=False)
self._testXent(features, labels, use_gpu=True)
def _testSingleClass(self, use_gpu=False):
with self.test_session(use_gpu=use_gpu) as sess:
loss = tf.nn.softmax_cross_entropy_with_logits(
np.array([[1.], [-1.], [0.]]).astype(np.float32),
np.array([[-1.], [0.], [1.]]).astype(np.float32))
backprop = loss.op.outputs[1]
tf_loss, tf_backprop = sess.run([loss, backprop])
self.assertAllClose([0.0, 0.0, 0.0], tf_loss)
self.assertAllClose([[2.0], [1.0], [0.0]], tf_backprop)
def testSingleClass(self):
self._testSingleClass(True)
self._testSingleClass(False)
def testRankTooLarge(self):
np_features = np.array(
[[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]]).astype(np.float32)
np_labels = np.array(
[[[0., 0., 0., 1.]], [[0., .5, .5, 0.]]]).astype(np.float32)
self.assertRaisesRegexp(
ValueError, "must have the same rank",
tf.nn.softmax_cross_entropy_with_logits, np_features, np_labels)
def testNpXent(self):
# We create 2 batches of logits for testing.
# batch 0 is the boring uniform distribution: 1, 1, 1, 1, with target 3.
......
......@@ -335,7 +335,10 @@ def _ConcatShape(op):
value.get_shape().assert_has_rank(rank)
else:
rank = value.get_shape().ndims
return [tensor_shape.unknown_shape(ndims=max(rank, 1))]
# TODO(irving): Remove once !kAllowLegacyScalars.
if rank is not None:
rank = max(rank, 1)
return [tensor_shape.unknown_shape(ndims=rank)]
else:
# Merge all the non-concat dims, and sum the concat dim to make an
......
......@@ -43,7 +43,7 @@ The convenience function [`resize_images()`](#resize_images) supports both 4-D
and 3-D tensors as input and output. 4-D tensors are for batches of images,
3-D tensors for individual images.
Other resizing Ops only support 3-D individual images as input:
Other resizing Ops only support 4-D batches of images as input:
[`resize_area`](#resize_area), [`resize_bicubic`](#resize_bicubic),
[`resize_bilinear`](#resize_bilinear),
[`resize_nearest_neighbor`](#resize_nearest_neighbor).
......@@ -51,9 +51,9 @@ Other resizing Ops only support 3-D individual images as input:
Example:
```python
# Decode a JPG image and resize it to 299 by 299.
# Decode a JPG image and resize it to 299 by 299 using default method.
image = tf.image.decode_jpeg(...)
resized_image = tf.image.resize_bilinear(image, [299, 299])
resized_image = tf.image.resize_images(image, 299, 299)
```
@@resize_images
......@@ -528,7 +528,7 @@ def resize_images(images, new_height, new_width, method=ResizeMethod.BILINEAR):
raise ValueError('Resize method is not implemented.')
if not is_batch:
images = array_ops.reshape(images, [new_height, new_width, depth])
images = array_ops.squeeze(images, squeeze_dims=[0])
return images
......
......@@ -17,15 +17,17 @@
"""## Activation Functions
The activation ops provide different types of nonlinearities for use in neural
networks. These include smooth nonlinearities (`sigmoid`, `tanh`, `softplus`,
and `softsign`), continuous but not everywhere differentiable functions (`relu`,
`relu6`, and `relu_x`), and random regularization (`dropout`).
networks. These include smooth nonlinearities (`sigmoid`, `tanh`, `elu`,
`softplus`, and `softsign`), continuous but not everywhere differentiable
functions (`relu`, `relu6`, and `relu_x`), and random regularization
(`dropout`).
All activation ops apply componentwise, and produce a tensor of the same
shape as the input tensor.
@@relu
@@relu6
@@elu
@@softplus
@@softsign
@@dropout
......
......@@ -132,6 +132,11 @@ def _Relu6Grad(op, grad):
return gen_nn_ops._relu6_grad(grad, op.inputs[0])
@ops.RegisterGradient("Elu")
def _EluGrad(op, grad):
return gen_nn_ops._elu_grad(grad, op.outputs[0])
@ops.RegisterGradient("Softplus")
def _SoftplusGrad(op, grad):
return gen_nn_ops._softplus_grad(grad, op.inputs[0])
......
......@@ -237,12 +237,14 @@ def max_pool(value, ksize, strides, padding, name=None):
ops.RegisterShape("Relu")(common_shapes.unchanged_shape)
ops.RegisterShape("Relu6")(common_shapes.unchanged_shape)
ops.RegisterShape("Elu")(common_shapes.unchanged_shape)
ops.RegisterShape("Softplus")(common_shapes.unchanged_shape)
ops.RegisterShape("Softsign")(common_shapes.unchanged_shape)
@ops.RegisterShape("ReluGrad")
@ops.RegisterShape("Relu6Grad")
@ops.RegisterShape("EluGrad")
@ops.RegisterShape("SoftplusGrad")
@ops.RegisterShape("SoftsignGrad")
def _BinaryElementwiseShape(op):
......
......@@ -147,7 +147,7 @@ def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols,
"""RNN decoder with embedding and a pure-decoding option.
Args:
decoder_inputs: a list of 1D batch-sized int32-Tensors (decoder inputs).
decoder_inputs: a list of 1D batch-sized int32 Tensors (decoder inputs).
initial_state: 2D Tensor [batch_size x cell.state_size].
cell: rnn_cell.RNNCell defining the cell function.
num_symbols: integer, how many symbols come into the embedding.
......@@ -219,8 +219,8 @@ def embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
encoder state, on embedded decoder_inputs.
Args:
encoder_inputs: a list of 1D int32-Tensors of shape [batch_size].
decoder_inputs: a list of 1D int32-Tensors of shape [batch_size].
encoder_inputs: a list of 1D int32 Tensors of shape [batch_size].
decoder_inputs: a list of 1D int32 Tensors of shape [batch_size].
cell: rnn_cell.RNNCell defining the cell function and size.
num_encoder_symbols: integer; number of symbols on the encoder side.
num_decoder_symbols: integer; number of symbols on the decoder side.
......@@ -286,8 +286,8 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
encoder state, on embedded decoder_inputs.
Args:
encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
encoder_inputs: a list of 1D int32 Tensors of shape [batch_size].
decoder_inputs: a list of 1D int32 Tensors of shape [batch_size].
cell: rnn_cell.RNNCell defining the cell function and size.
num_symbols: integer; number of symbols for both encoder and decoder.
output_projection: None or a pair (W, B) of output projection weights and
......@@ -486,7 +486,7 @@ def embedding_attention_decoder(decoder_inputs, initial_state, attention_states,
"""RNN decoder with embedding and attention and a pure-decoding option.
Args:
decoder_inputs: a list of 1D batch-sized int32-Tensors (decoder inputs).
decoder_inputs: a list of 1D batch-sized int32 Tensors (decoder inputs).
initial_state: 2D Tensor [batch_size x cell.state_size].
attention_states: 3D Tensor [batch_size x attn_length x attn_size].
cell: rnn_cell.RNNCell defining the cell function.
......@@ -566,8 +566,8 @@ def embedding_attention_seq2seq(encoder_inputs, decoder_inputs, cell,
encoder state, on embedded decoder_inputs and attending to encoder outputs.
Args:
encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
encoder_inputs: a list of 1D int32 Tensors of shape [batch_size].
decoder_inputs: a list of 1D int32 Tensors of shape [batch_size].
cell: rnn_cell.RNNCell defining the cell function and size.
num_encoder_symbols: integer; number of symbols on the encoder side.
num_decoder_symbols: integer; number of symbols on the decoder side.
......@@ -636,7 +636,7 @@ def sequence_loss_by_example(logits, targets, weights, num_decoder_symbols,
Args:
logits: list of 2D Tensors of shape [batch_size x num_decoder_symbols].
targets: list of 1D batch-sized int32-Tensors of the same length as logits.
targets: list of 1D batch-sized int32 Tensors of the same length as logits.
weights: list of 1D batch-sized float-Tensors of the same length as logits.
num_decoder_symbols: integer, number of decoder symbols (output classes).
average_across_timesteps: If set, divide the returned cost by the total
......@@ -692,7 +692,7 @@ def sequence_loss(logits, targets, weights, num_decoder_symbols,
Args:
logits: list of 2D Tensors os shape [batch_size x num_decoder_symbols].
targets: list of 1D batch-sized int32-Tensors of the same length as logits.
targets: list of 1D batch-sized int32 Tensors of the same length as logits.
weights: list of 1D batch-sized float-Tensors of the same length as logits.
num_decoder_symbols: integer, number of decoder symbols (output classes).
average_across_timesteps: If set, divide the returned cost by the total
......@@ -731,7 +731,7 @@ def model_with_buckets(encoder_inputs, decoder_inputs, targets, weights,
Args:
encoder_inputs: a list of Tensors to feed the encoder; first seq2seq input.
decoder_inputs: a list of Tensors to feed the decoder; second seq2seq input.
targets: a list of 1D batch-sized int32-Tensors (desired output sequence).
targets: a list of 1D batch-sized int32 Tensors (desired output sequence).
weights: list of 1D batch-sized float-Tensors to weight the targets.
buckets: a list of pairs of (input size, output size) for each bucket.
num_decoder_symbols: integer, number of decoder symbols (output classes).
......
......@@ -45,6 +45,7 @@ create variables contingent on certain conditions.
@@get_variable
@@get_variable_scope
@@variable_op_scope
@@variable_scope
@@constant_initializer
......
......@@ -131,12 +131,14 @@ class _VariableScope(object):
name: name of the current scope, used as prefix in get_variable.
initializer: default initializer passed to get_variable.
reuse: Boolean or None, setting the reuse in get_variable.
name_scope: The name passed to tf.name_scope.
"""
def __init__(self, reuse, name="", initializer=None):
def __init__(self, reuse, name="", initializer=None, name_scope=""):
self._name = name
self._initializer = initializer
self._reuse = reuse
self._name_scope = name_scope
@property
def name(self):
......@@ -238,6 +240,60 @@ def get_variable(name, shape=None, dtype=dtypes.float32, initializer=None,
trainable, collections)
@contextlib.contextmanager
def _pure_variable_scope(name_or_scope, reuse=None, initializer=None):
"""Creates a context for the variable_scope, see `variable_scope` for docs.
Note: this does not create a name scope.
Args:
name_or_scope: `string` or `VariableScope`: the scope to open.
reuse: `True` or `None`; if `True`, we go into reuse mode for this scope as
well as all sub-scopes; if `None`, we just inherit the parent scope reuse.
initializer: default initializer for variables within this scope.
Yields:
A scope that can be to captured and reused.
Raises:
ValueError: when trying to reuse within a create scope, or create within
a reuse scope, or if reuse is not `None` or `True`.
TypeError: when the types of some arguments are not appropriate.
"""
get_variable_scope() # Ensure that a default exists, then get a pointer.
default_varscope = ops.get_collection(_VARSCOPE_KEY)
try:
old = default_varscope[0]
reuse = reuse or old.reuse # Re-using is inherited by sub-scopes.
if isinstance(name_or_scope, _VariableScope):
name_scope = name_or_scope._name_scope # pylint: disable=protected-access
# Handler for the case when we jump to a shared scope.
# We create a new VariableScope (default_varscope[0]) that contains
# a copy of the provided shared scope, possibly with changed reuse
# and initializer, if the user requested this.
default_varscope[0] = _VariableScope(reuse, name_or_scope.name,
name_or_scope.initializer,
name_scope)
if initializer:
default_varscope[0].set_initializer(initializer)
yield default_varscope[0]
else:
# Handler for the case when we just prolong current variable scope.
# VariableScope with name extended by the provided one, and inherited
# reuse and initializer (except if the user provided values to set).
new_name = old.name + "/" + name_or_scope if old.name else name_or_scope
default_varscope[0] = _VariableScope(reuse, name=new_name,
initializer=old.initializer,
name_scope=name_or_scope)
if initializer:
default_varscope[0].set_initializer(initializer)
yield default_varscope[0]
finally:
default_varscope[0] = old
# pylint: disable=g-doc-return-or-yield
@contextlib.contextmanager
def variable_scope(name_or_scope, reuse=None, initializer=None):
"""Returns a context for variable scope.
......@@ -304,7 +360,7 @@ def variable_scope(name_or_scope, reuse=None, initializer=None):
well as all sub-scopes; if `None`, we just inherit the parent scope reuse.
initializer: default initializer for variables within this scope.
Yields:
Returns:
A scope that can be to captured and reused.
Raises:
......@@ -315,39 +371,76 @@ def variable_scope(name_or_scope, reuse=None, initializer=None):
if not isinstance(name_or_scope, (_VariableScope,) + six.string_types):
raise TypeError("VariableScope: name_scope must be a string or "
"VariableScope.")
if reuse not in [None, True]:
raise ValueError("VariableScope reuse parameter must be True or None.")
if not reuse and isinstance(name_or_scope, (_VariableScope)):
logging.info("Passing VariableScope to a non-reusing scope, intended?")
if reuse and isinstance(name_or_scope, six.string_types):
logging.info("Re-using string-named scope, consider capturing as object.")
get_variable_scope() # Ensure that a default exists, then get a pointer.
default_varscope = ops.get_collection(_VARSCOPE_KEY)
try:
old = default_varscope[0]
reuse = reuse or old.reuse # Re-using is inherited by sub-scopes.
if isinstance(name_or_scope, _VariableScope):
# Handler for the case when we jump to a shared scope.
# In this case, we leave the current name_scope unchanged.
# We create a new VariableScope (default_varscope[0]) that contains
# a copy of the provided shared scope, possibly with changed reuse
# and initializer, if the user requested this.
default_varscope[0] = _VariableScope(reuse, name_or_scope.name,
name_or_scope.initializer)
if initializer:
default_varscope[0].set_initializer(initializer)
yield default_varscope[0]
if isinstance(name_or_scope, six.string_types):
name = name_or_scope
else:
name = name_or_scope._name_scope # pylint: disable=protected-access
if name:
with ops.name_scope(name), _pure_variable_scope(
name_or_scope, reuse, initializer) as vs:
yield vs
else:
# This can only happen if someone is entering the root variable scope.
with _pure_variable_scope(name_or_scope, reuse, initializer) as vs:
yield vs
# pylint: disable=g-doc-return-or-yield
@contextlib.contextmanager
def variable_op_scope(values, name, default_name, initializer=None):
"""Returns a context manager for defining an op that creates variables.
This context manager validates that the given `values` are from the
same graph, ensures that that graph is the default graph, and pushes a
name scope and a variable scope.
If `name` is not None, it is used as is in the variable scope. If `name`
is None, then `default_name` is used. In that case, if the same name has been
previously used in the same scope, it will made unique be appending `_N` to
it.
This is intended to be used when defining generic ops and so reuse is always
inherited.
For example, to define a new Python op called `my_op_with_vars`:
```python
def my_op_with_vars(a, b, name=None):
with tf.variable_op_scope([a, b], name, "MyOp") as scope:
a = tf.convert_to_tensor(a, name="a")
b = tf.convert_to_tensor(b, name="b")
c = tf.get_variable('c')
# Define some computation that uses `a`, `b`, and `c`.
return foo_op(..., name=scope)
```
Args:
values: The list of `Tensor` arguments that are passed to the op function.
name: The name argument that is passed to the op function, this name is not
uniquified in the variable scope.
default_name: The default name to use if the `name` argument is `None`, this
name will be uniquified.
initializer: A default initializer to pass to variable scope.
Returns:
A context manager for use in defining a Python op.
Raises:
ValueError: when trying to reuse within a create scope, or create within
a reuse scope, or if reuse is not `None` or `True`.
TypeError: when the types of some arguments are not appropriate.
"""
if default_name is None:
raise TypeError("default_name cannot be None")
g = ops._get_graph_from_inputs(values) # pylint: disable=protected-access
with g.as_default():
if name:
with variable_scope(name, initializer=initializer) as vs:
yield vs
else:
# Handler for the case when we just prolong current variable scope.
# In this case we prolong the current name_scope and create a new
# VariableScope with name extended by the provided one, and inherited
# reuse and initializer (except if the user provided values to set).
with ops.name_scope(name_or_scope):
new_name = old.name + "/" + name_or_scope if old.name else name_or_scope
default_varscope[0] = _VariableScope(reuse, name=new_name,
initializer=old.initializer)
if initializer:
default_varscope[0].set_initializer(initializer)
yield default_varscope[0]
finally:
default_varscope[0] = old
with ops.name_scope(default_name) as scope:
count = len(default_name.split("/"))
scoped_name = "/".join(scope.split("/")[-count - 1:-1])
with _pure_variable_scope(scoped_name,
initializer=initializer) as vs:
yield vs
......@@ -13,3 +13,4 @@ components/tf-graph-common/lib/scene/*.js
components/tf-event-dashboard/*.js
components/tf-categorizer/*.js
components/tf-dashboard-common/*.js
components/**/test/*.js
// hackhack for some reason getting graphlib via an import reference results in
// out of order script evaluation
<script src="../../lodash/lodash.min.js"></script>
<script src="../../graphlib/dist/graphlib.core.js"></script>
<script src="../../dagre/dist/dagre.core.js"></script>
<script src="../../lodash/lodash.min.js"></script>
<script src="../../graphlib/dist/graphlib.core.js"></script>
......@@ -77,6 +77,32 @@ paper-progress {
--paper-progress-height: 6px;
--paper-progress-active-color: #f3913e;
}
.context-menu {
position: absolute;
display: none;
background-color: #e2e2e2;
border-radius: 2px;
font-size: 14px;
min-width: 150px;
border: 1px solid #d4d4d4;
}
/deep/ .context-menu ul {
list-style-type: none;
margin: 0;
padding: 0;
cursor: default;
}
/deep/ .context-menu ul li {
padding: 4px 16px;
}
/deep/ .context-menu ul li:hover {
background-color: #f3913e;
color: white;
}
</style>
<template is="dom-if" if="[[_isNotComplete(progress)]]">
<div id="progress-bar">
......@@ -104,11 +130,13 @@ paper-progress {
render-hierarchy="[[_renderHierarchy]]"
graph="[[graph]]"
selected-node="{{_selectedNode}}"
selected-node-include="{{_selectedNodeInclude}}"
highlighted-node="{{_highlightedNode}}"
color-by="[[colorBy]]"
color-by-params="[[colorByParams]]"
></tf-graph-info>
</div>
<div class="context-menu"></div>
</div>
</template>
</dom-module>
......@@ -137,9 +165,17 @@ Polymer({
},
// Private API: Data routing between child components.
_selectedNode: String,
// The enum value of the include property of the selected node.
_selectedNodeInclude: Number,
_highlightedNode: String,
_renderHierarchy: Object,
},
listeners: {
'node-toggle-extract': '_nodeToggleExtract'
},
observers: [
'_updateNodeInclude(_selectedNode)'
],
/** True if the progress is not complete yet (< 100 %). */
_isNotComplete: function(progress) {
return progress.value < 100;
......@@ -153,6 +189,14 @@ Polymer({
result += ' loading';
}
return result;
},
_updateNodeInclude: function(nodeName) {
var node = this.graphHierarchy.node(nodeName);
this.set("_selectedNodeInclude",
node ? node.include : tf.graph.InclusionType.UNSPECIFIED);
},
_nodeToggleExtract: function() {
this._updateNodeInclude(this._selectedNode);
}
});
</script>
......@@ -29,6 +29,9 @@ export enum GraphType {FULL, EMBEDDED, META, SERIES, CORE, SHADOW, BRIDGE,
EDGE};
export enum NodeType {META, OP, SERIES, BRIDGE, ELLIPSIS};
/** Indicates if a node is to be included in the main graph when rendered. */
export enum InclusionType {INCLUDE, EXCLUDE, UNSPECIFIED};
/**
* A BaseEdge is the label object (in the graphlib sense) for an edge in the
* original, full graph produced after parsing. Subsequent graphs, like those
......@@ -98,6 +101,12 @@ export interface Node {
parentNode: Node;
/** Runtime execution stats for this node, if available */
stats: NodeStats;
/** If the node is to be included or excluded from the main graph when
* rendered. Defaults to UNSPECIFIED, which means that the rendering
* algorithm determines if it will be included or not. Then can be set to
* INCLUDE or EXCLUDE manually by the user.
*/
include: InclusionType;
}
export interface OpNode extends Node {
......@@ -258,6 +267,7 @@ export class EllipsisNodeImpl implements EllipsisNode {
isGroupNode: boolean;
cardinality: number;
parentNode: Node;
include: InclusionType;
/**
* Constructs a new ellipsis annotation node.
......@@ -271,6 +281,7 @@ export class EllipsisNodeImpl implements EllipsisNode {
this.parentNode = null;
this.stats = null;
this.setNumMoreNodes(numNodes);
this.include = InclusionType.UNSPECIFIED;
}
setNumMoreNodes(numNodes: number) {
......@@ -296,6 +307,7 @@ class OpNodeImpl implements OpNode {
inEmbeddings: OpNode[];
outEmbeddings: OpNode[];
parentNode: Node;
include: InclusionType;
/**
* Constructs a new Op node.
......@@ -319,6 +331,7 @@ class OpNodeImpl implements OpNode {
this.inEmbeddings = [];
this.outEmbeddings = [];
this.parentNode = null;
this.include = InclusionType.UNSPECIFIED;
}
};
......@@ -419,6 +432,7 @@ class MetanodeImpl implements Metanode {
deviceHistogram: {[op: string]: number};
parentNode: Node;
hasNonControlEdges: boolean;
include: InclusionType;
/** A label object for meta-nodes in the graph hierarchy */
constructor(name: string, opt = {}) {
......@@ -448,6 +462,7 @@ class MetanodeImpl implements Metanode {
this.parentNode = null;
this.stats = new NodeStats(0, 0, null);
this.hasNonControlEdges = false;
this.include = InclusionType.UNSPECIFIED;
}
getFirstChild(): GroupNode|OpNode {
......@@ -599,6 +614,7 @@ class SeriesNodeImpl implements SeriesNode {
parentNode: Node;
deviceHistogram: {[op: string]: number};
hasNonControlEdges: boolean;
include: InclusionType;
constructor(prefix: string, suffix: string, parent: string,
clusterId: number, name: string) {
......@@ -619,6 +635,7 @@ class SeriesNodeImpl implements SeriesNode {
this.deviceHistogram = {};
this.hasNonControlEdges = false;
this.stats = new NodeStats(0, 0, null);
this.include = InclusionType.UNSPECIFIED;
}
}
......@@ -901,4 +918,15 @@ export function getHierarchicalPath(name: string,
return path;
};
/**
* Returns the string for the node inclusion toggle button, dependant
* on the provided current InclusionType.
*/
export function getIncludeNodeButtonString(include: InclusionType) {
if (include === tf.graph.InclusionType.EXCLUDE) {
return "Add to main graph";
} else {
return "Remove from main graph";
}
};
} // close module tf.graph
......@@ -288,6 +288,24 @@ export class RenderGraphInformation {
setGroupNodeDepth(this.root, +depth);
}
/**
* Returns true if the renderNode is an isolated node within its parent node.
*/
isNodeAuxilliary(renderNode: RenderNodeInformation): boolean {
let parentNode = <RenderGroupNodeInformation>this.getRenderNodeByName(
renderNode.node.parentNode.name);
let found = _.find(parentNode.isolatedInExtract, node => {
return node.node.name === renderNode.node.name;
});
if (found) {
return true;
}
found = _.find(parentNode.isolatedOutExtract, node => {
return node.node.name === renderNode.node.name;
});
return !!found;
}
buildSubhierarchy(nodeName: string): void {
// Terminate if the rendering hierarchy was already constructed
// for this node.
......@@ -555,6 +573,7 @@ export class RenderGraphInformation {
cardinality: 0,
parentNode: null,
stats: null,
include: InclusionType.UNSPECIFIED,
// BridgeNode properties.
inbound: inbound,
};
......@@ -573,6 +592,7 @@ export class RenderGraphInformation {
cardinality: 1,
parentNode: null,
stats: null,
include: InclusionType.UNSPECIFIED,
// BridgeNode properties.
inbound: inbound,
};
......@@ -692,6 +712,7 @@ export class RenderGraphInformation {
cardinality: 1,
parentNode: null,
stats: null,
include: InclusionType.UNSPECIFIED,
// BridgeNode properties.
inbound: inbound,
};
......@@ -1127,6 +1148,16 @@ function createShortcut(graph: graphlib.Graph<RenderNodeInformation, {}>,
let sink = graph.node(w);
let edge = graph.edge(v, w);
// If either of the nodes is explicitly included in the main graph and
// both nodes are in the main graph then do not create the shortcut
// and instead keep the real edge.
if ((src.node.include === InclusionType.INCLUDE ||
sink.node.include === InclusionType.INCLUDE) &&
src.node.include !== InclusionType.EXCLUDE &&
sink.node.include !== InclusionType.EXCLUDE) {
return;
}
// Add each annotation.
addOutAnnotation(src, sink.node, sink, edge, AnnotationType.SHORTCUT, params);
addInAnnotation(sink, src.node, src, edge, AnnotationType.SHORTCUT, params);
......@@ -1139,27 +1170,29 @@ function createShortcut(graph: graphlib.Graph<RenderNodeInformation, {}>,
* Remove edges from a node, and set its isOutExtract property to true,
* and remove the node and move it to isolatedOutExtract.
*
* If detachAllEdgesForHighDegree is true, extract all of its edges.
* Otherwise, only extract all in-edges.
* If detachAllEdgesForHighDegree or forceDetach is true, extract all of its
* edges. Otherwise, only extract all in-edges.
*/
function makeOutExtract(renderNode: RenderGroupNodeInformation, n: string,
params: RenderGraphParams) {
params: RenderGraphParams, forceDetach?: boolean) {
let graph = renderNode.coreGraph;
graph.node(n).isOutExtract = true;
let child = graph.node(n);
child.isOutExtract = true;
_.each(graph.predecessors(n), (p, index) => {
createShortcut(graph, p, n, params);
});
if (params.detachAllEdgesForHighDegree) {
if (params.detachAllEdgesForHighDegree || forceDetach) {
_.each(graph.successors(n), (s, index) => {
createShortcut(graph, n, s, params);
});
}
if (params.detachAllEdgesForHighDegree || graph.neighbors(n).length === 0) {
renderNode.isolatedOutExtract.push(graph.node(n));
// Remove the node from the core graph if it no longer has neighbors.
if (graph.neighbors(n).length === 0) {
child.node.include = InclusionType.EXCLUDE;
renderNode.isolatedOutExtract.push(child);
graph.removeNode(n);
}
}
......@@ -1167,27 +1200,30 @@ function makeOutExtract(renderNode: RenderGroupNodeInformation, n: string,
/**
* Remove edges from a node, set its isInExtract property to true,
* and remove the node and move it to isolatedInExtract.
* If detachAllEdgesForHighDegree is true, extract all of its edges.
* Otherwise, only remove all out-edges.
*
* If detachAllEdgesForHighDegree or forceDetach is true, extract all of its
* edges. Otherwise, only remove all out-edges.
*/
function makeInExtract(renderNode: RenderGroupNodeInformation, n: string,
params: RenderGraphParams) {
export function makeInExtract(renderNode: RenderGroupNodeInformation, n: string,
params: RenderGraphParams, forceDetach?: boolean) {
let graph = renderNode.coreGraph;
graph.node(n).isInExtract = true;
let child = graph.node(n);
child.isInExtract = true;
_.each(graph.successors(n), (s, index) => {
createShortcut(graph, n, s, params);
});
if (params.detachAllEdgesForHighDegree) {
if (params.detachAllEdgesForHighDegree || forceDetach) {
_.each(graph.predecessors(n), (p, index) => {
createShortcut(graph, p, n, params);
});
}
// Remove the node from the core graph if conditions are met.
if (params.detachAllEdgesForHighDegree || graph.neighbors(n).length === 0) {
renderNode.isolatedInExtract.push(graph.node(n));
// Remove the node from the core graph if it no longer has neighbors.
if (graph.neighbors(n).length === 0) {
child.node.include = InclusionType.EXCLUDE;
renderNode.isolatedInExtract.push(child);
graph.removeNode(n);
}
}
......@@ -1214,12 +1250,32 @@ function hasTypeIn(node: Node, types: string[]): boolean {
return false;
}
/** Move nodes that are speficied to be excluded out of the core graph. */
function extractSpeficiedNodes(renderNode: RenderGroupNodeInformation,
params: RenderGraphParams) {
let graph = renderNode.coreGraph;
_.each(graph.nodes(), n => {
let renderInfo = graph.node(n);
if (renderInfo.node.include === InclusionType.EXCLUDE) {
if (renderNode.coreGraph.outEdges(n).length >
renderNode.coreGraph.inEdges(n).length) {
makeOutExtract(renderNode, n, params, true);
} else {
makeInExtract(renderNode, n, params, true);
}
}
});
}
/** Remove edges from pre-defined out-extract patterns */
function extractPredefinedSink(renderNode: RenderGroupNodeInformation,
params: RenderGraphParams) {
let graph = renderNode.coreGraph;
_.each(graph.nodes(), n => {
let renderInfo = graph.node(n);
if (renderInfo.node.include !== InclusionType.UNSPECIFIED) {
return;
}
if (hasTypeIn(renderInfo.node, params.outExtractTypes)) {
makeOutExtract(renderNode, n, params);
}
......@@ -1233,6 +1289,9 @@ function extractPredefinedSource(renderNode: RenderGroupNodeInformation,
_.each(graph.nodes(), n => {
let renderInfo = graph.node(n);
if (renderInfo.node.include !== InclusionType.UNSPECIFIED) {
return;
}
if (hasTypeIn(renderInfo.node, params.inExtractTypes)) {
makeInExtract(renderNode, n, params);
}
......@@ -1247,6 +1306,9 @@ function extractHighInDegree(renderNode: RenderGroupNodeInformation,
// detect first so degrees don't get affected by other removal
let highInDegreeNames = _.filter(graph.nodes(), n => {
if (graph.node(n).node.include !== InclusionType.UNSPECIFIED) {
return false;
}
// Count the in-degree based on only regular edges, unless there are
// no regular edges, in which case use the number of control edges.
// This is done so that control edges don't effect if nodes are extracted
......@@ -1274,6 +1336,9 @@ function extractHighOutDegree(renderNode: RenderGroupNodeInformation,
// detect first so degrees don't get affected by other removal
let highOutDegreeNames = _.filter(graph.nodes(), n => {
if (graph.node(n).node.include !== InclusionType.UNSPECIFIED) {
return false;
}
// Count the out-degree based on only regular edges, unless there are
// no regular edges, in which case use the number of control edges.
// This is done so that control edges don't effect if nodes are extracted
......@@ -1345,6 +1410,9 @@ export function mapIndexToHue(id: number): number {
*/
function extractHighDegrees(renderNode: RenderGroupNodeInformation,
params: RenderGraphParams) {
extractSpeficiedNodes(renderNode, params);
if (params.outExtractTypes) {
extractPredefinedSink(renderNode, params);
}
......@@ -1386,7 +1454,9 @@ function extractHighDegrees(renderNode: RenderGroupNodeInformation,
_.each(graph.nodes(), n => {
let child = graph.node(n);
let degree = graph.neighbors(n).length;
if (child.node.include !== InclusionType.UNSPECIFIED) {
return;
}
if (degree === 0) {
let hasOutAnnotations = child.outAnnotations.list.length > 0;
let hasInAnnotations = child.inAnnotations.list.length > 0;
......@@ -1395,20 +1465,24 @@ function extractHighDegrees(renderNode: RenderGroupNodeInformation,
// This case only happens if detachAllEdgesForHighDegree is false.
// (Otherwise all source-like nodes are all isolated already.)
renderNode.isolatedInExtract.push(child);
child.node.include = InclusionType.EXCLUDE;
graph.removeNode(n);
} else if (child.isOutExtract) { // Is sink-like.
// This case only happens if detachAllEdgesForHighDegree is false.
// // (Otherwise all sink-like nodes are all isolated already.)
renderNode.isolatedOutExtract.push(child);
child.node.include = InclusionType.EXCLUDE;
graph.removeNode(n);
} else if (params.extractIsolatedNodesWithAnnotationsOnOneSide) {
if (hasOutAnnotations && !hasInAnnotations) {
child.isInExtract = true; // for ones with high out-annotations
renderNode.isolatedInExtract.push(child);
child.node.include = InclusionType.EXCLUDE;
graph.removeNode(n);
} else if (hasInAnnotations && !hasOutAnnotations) {
child.isOutExtract = true; // for ones with high in-annotations
renderNode.isolatedOutExtract.push(child);
child.node.include = InclusionType.EXCLUDE;
graph.removeNode(n);
} else {
// if a low degree node has both in- & out- annotations, do nothing
......
......@@ -17,6 +17,7 @@ limitations under the License.
/// <reference path="../render.ts" />
/// <reference path="scene.ts" />
/// <reference path="edge.ts" />
/// <reference path="contextmenu.ts" />
module tf.graph.scene.annotation {
......@@ -90,7 +91,7 @@ export function buildGroup(container, annotationData: render.AnnotationList,
let aGroup = d3.select(this);
update(aGroup, d, a, sceneBehavior);
if (a.annotationType !== tf.graph.render.AnnotationType.ELLIPSIS) {
addInteraction(aGroup, d, sceneBehavior);
addInteraction(aGroup, d, a, sceneBehavior);
}
});
......@@ -151,7 +152,7 @@ function addAnnotationLabel(aGroup, label, a, additionalClassNames,
}
function addInteraction(selection, d: render.RenderNodeInformation,
sceneBehavior) {
annotation: tf.graph.render.Annotation, sceneBehavior) {
selection
.on("mouseover", a => {
sceneBehavior.fire("annotation-highlight", {
......@@ -174,6 +175,11 @@ function addInteraction(selection, d: render.RenderNodeInformation,
hostName: d.node.name
});
});
if (annotation.annotationType !== tf.graph.render.AnnotationType.SUMMARY &&
annotation.annotationType !== tf.graph.render.AnnotationType.CONSTANT) {
selection.on("contextmenu", tf.graph.scene.contextmenu.getMenu(
tf.graph.scene.node.getContextMenu(annotation.node, sceneBehavior)));
}
};
/**
......
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
module tf.graph.scene.contextmenu {
/** Function that converts data to a title string. */
export interface TitleFunction {
(data: any): string;
}
/** Function that takes action based on item clicked in the context menu. */
export interface ActionFunction {
(elem: any, d: any, i: number): void;
}
/**
* The interface for an item in the context menu
*/
export interface ContextMenuItem {
title: TitleFunction;
action: ActionFunction;
}
/**
* Returns the event listener, which can be used as an argument for the d3
* selection.on function. Renders the context menu that is to be displayed
* in response to the event.
*/
export function getMenu(menu: ContextMenuItem[]) {
let menuSelection = d3.select(".context-menu");
// Close the menu when anything else is clicked.
d3.select("body").on("click.context", function() {
menuSelection.style("display", "none");
});
// Function called to populate the context menu.
return function(data, index: number): void {
// Position and display the menu.
let event = <MouseEvent>d3.event;
menuSelection.style({
"display": "block",
"left": (event.layerX + 1) + "px",
"top": (event.layerY + 1) + "px"
});
// Stop the event from propagating further.
event.preventDefault();
event.stopPropagation();
// Add provided items to the context menu.
menuSelection.html("");
let list = menuSelection.append("ul");
list.selectAll("li").data(menu).enter()
.append("li")
.html(function(d) {
return d.title(data);
})
.on("click", (d, i) => {
d.action(this, data, index);
menuSelection.style("display", "none");
});
};
};
} // close module
......@@ -16,6 +16,7 @@ limitations under the License.
/// <reference path="../graph.ts" />
/// <reference path="scene.ts" />
/// <reference path="annotation.ts" />
/// <reference path="contextmenu.ts" />
module tf.graph.scene.node {
......@@ -229,31 +230,52 @@ function addInteraction(selection, d: render.RenderNodeInformation,
selection.attr("pointer-events", "none");
return;
}
let contextMenuFunction = tf.graph.scene.contextmenu.getMenu(
getContextMenu(d.node, sceneBehavior));
selection.on("dblclick", d => {
sceneBehavior.fire("node-toggle-expand", { name: d.node.name });
sceneBehavior.fire("node-toggle-expand", { name: d.node.name });
})
.on("mouseover", d => {
// don't send mouseover over expanded group,
// otherwise it is causing too much glitches
if (sceneBehavior.isNodeExpanded(d)) { return; }
if (sceneBehavior.isNodeExpanded(d)) { return; }
sceneBehavior.fire("node-highlight", { name: d.node.name });
})
.on("mouseout", d => {
// don't send mouseover over expanded group,
// otherwise it is causing too much glitches
if (sceneBehavior.isNodeExpanded(d)) { return; }
if (sceneBehavior.isNodeExpanded(d)) { return; }
sceneBehavior.fire("node-unhighlight", { name: d.node.name });
sceneBehavior.fire("node-unhighlight", { name: d.node.name });
})
.on("click", d => {
// Stop this event's propagation so that it isn't also considered
// a graph-select.
(<Event>d3.event).stopPropagation();
sceneBehavior.fire("node-select", { name: d.node.name });
})
.on("contextmenu", (d, i) => {
sceneBehavior.fire("node-select", { name: d.node.name });
contextMenuFunction.call(d, i);
});
};
/**
* Returns the d3 context menu specification for the provided node.
*/
export function getContextMenu(node: Node, sceneBehavior) {
return [{
title: d => {
return tf.graph.getIncludeNodeButtonString(node.include);
},
action: (elm, d, i) => {
sceneBehavior.fire("node-toggle-extract", { name: node.name });
}
}];
}
/**
* Append svg text for label and assign data.
* @param nodeGroup
......
......@@ -13,5 +13,6 @@
<script src="lib/scene/annotation.js"></script>
<script src="lib/scene/edge.js"></script>
<script src="lib/scene/node.js"></script>
<script src="lib/scene/contextmenu.js"></script>
<script src="lib/layout.js"></script>
<script src="lib/colors.js"></script>
......@@ -83,6 +83,7 @@ by default. The user can select a different run from a dropdown menu.
}
.center {
position: relative;
height: 100%;
}
......
......@@ -22,6 +22,7 @@ h2 {
render-hierarchy="[[renderHierarchy]]"
flat-graph="[[graph]]"
node-name="[[selectedNode]]"
node-include="[[selectedNodeInclude]]"
highlighted-node="{{highlightedNode}}"
color-by="[[colorBy]]">
</tf-node-info>
......@@ -47,6 +48,11 @@ h2 {
highlightedNode: {
type: String,
notify: true
},
// The enum value of the include property of the selected node.
selectedNodeInclude: {
type: Number,
notify: true
}
},
listeners: {
......
......@@ -89,6 +89,23 @@
padding: 0;
}
.toggle-include-group {
padding-top: 4px;
}
.toggle-include {
margin: 5px 6px;
text-transform: none;
padding: 4px 6px;
font-size: 10pt;
background-color: #fafafa;
color: #666;
}
.toggle-include:hover {
background-color: var(--google-yellow-100);
}
.non-control-list-item {
padding-left: 10px;
}
......@@ -248,6 +265,11 @@
</div>
</template>
</div>
<div class="toggle-include-group">
<paper-button raised class="toggle-include" on-click="_toggleInclude">
<span>[[_auxButtonText]]</span>
</paper-button>
</div>
</div>
</template>
</iron-collapse>
......@@ -273,6 +295,11 @@
computed: '_getNode(nodeName, graphHierarchy)',
observer: '_resetState'
},
// The enum value of the include property of the selected node.
nodeInclude: {
type: Number,
observer: '_nodeIncludeStateChanged'
},
_attributes: {
type: Array,
computed: '_getAttributes(_node)'
......@@ -313,6 +340,7 @@
type: Boolean,
value: false
},
_auxButtonText: String
},
expandNode: function() {
this.fire('_node.expand', this.node);
......@@ -379,6 +407,16 @@
if (list) {
list.fire('iron-resize');
}
},
_toggleInclude: function() {
var graphElem = document.querySelector("#graph");
graphElem.fire("node-toggle-extract", { name: this.nodeName });
var graphBoardElem = document.querySelector("#graphboard");
graphBoardElem.fire("node-toggle-extract");
},
_nodeIncludeStateChanged: function(include, oldInclude) {
this.set("_auxButtonText",
tf.graph.getIncludeNodeButtonString(include));
}
});
})();
......
<link rel="import" href="../../../bower_components/polymer/polymer.html">
<link rel="import" href="../../polymer/polymer.html">
<link rel="import" href="../../tf-graph-board/tf-graph-board.html">
<link rel="import" href="../../tf-graph-loader/tf-graph-loader.html">
<link rel="import" href="../../tf-graph/tf-graph-controls.html">
......
......@@ -68,7 +68,7 @@ Module for adjusting render graph building parameter.
*/
detachAllEdgesForHighDegree: {
type: Boolean,
value: false
value: true
},
/**
......
......@@ -82,7 +82,6 @@ Polymer({
type: Object,
readOnly: true,
notify: true,
computed: '_buildRenderHierarchy(graphHierarchy, _graphParams)'
},
// internal properties
_graphParams: {
......@@ -100,8 +99,11 @@ Polymer({
value: true
}
},
observers: [
'_buildRenderHierarchy(graphHierarchy, _graphParams)'
],
_buildRenderHierarchy: function(graphHierarchy, params) {
return tf.time('new tf.graph.render.Hierarchy', function() {
tf.time('new tf.graph.render.Hierarchy', function() {
if (graphHierarchy.root.type !== tf.graph.NodeType.META) {
// root must be metanode but sometimes Polymer's dom-if has not
// remove tf-graph element yet in <tf-node-info>
......@@ -135,7 +137,7 @@ Polymer({
};
})
});
return renderGraph;
this._setRenderHierarchy(renderGraph);
}.bind(this));
},
_getVisible: function(name) {
......@@ -153,6 +155,7 @@ Polymer({
'node-select': '_nodeSelected',
'node-highlight': '_nodeHighlighted',
'node-unhighlight': '_nodeUnhighlighted',
'node-toggle-extract': '_nodeToggleExtract',
// Annotations
......@@ -214,6 +217,23 @@ Polymer({
// Also select the expanded node.
this._nodeSelected(event);
},
_nodeToggleExtract: function(event) {
// Toggle the include setting of the specified node appropriately.
var nodeName = event.detail.name;
var renderNode = this.renderHierarchy.getRenderNodeByName(nodeName);
if (renderNode.node.include == tf.graph.InclusionType.INCLUDE) {
renderNode.node.include = tf.graph.InclusionType.EXCLUDE;
} else if (renderNode.node.include == tf.graph.InclusionType.EXCLUDE) {
renderNode.node.include = tf.graph.InclusionType.INCLUDE;
} else {
renderNode.node.include =
this.renderHierarchy.isNodeAuxilliary(renderNode)
? tf.graph.InclusionType.INCLUDE : tf.graph.InclusionType.EXCLUDE;
}
// Rebuild the render hierarchy.
this._buildRenderHierarchy(this.graphHierarchy, this._graphParams);
},
not: function(x) {
return !x;
}
......
......@@ -284,30 +284,33 @@ _py_wrap_cc = rule(attrs={
},
implementation=_py_wrap_cc_impl,)
def tf_extension_linkopts():
return [] # No extension link opts
def tf_extension_copts():
return [] # No extension c opts
def tf_py_wrap_cc(name, srcs, swig_includes=[], deps=[], copts=[], **kwargs):
module_name = name.split("/")[-1]
# Convert a rule name such as foo/bar/baz to foo/bar/_baz.so
# and use that as the name for the rule producing the .so file.
cc_library_name = "/".join(name.split("/")[:-1] + ["_" + module_name + ".so"])
extra_deps = []
_py_wrap_cc(name=name + "_py_wrap",
srcs=srcs,
swig_includes=swig_includes,
deps=deps,
deps=deps + extra_deps,
module_name=module_name,
py_module_name=name)
native.cc_binary(
name=cc_library_name,
srcs=[module_name + ".cc"],
copts=copts + ["-Wno-self-assign", "-Wno-write-strings"
] + ["-I/usr/include/python2.7"],
copts=(copts + ["-Wno-self-assign", "-Wno-write-strings"]
+ tf_extension_copts()),
linkopts=tf_extension_linkopts(),
linkstatic=1,
linkshared=1,
deps=deps)
deps=deps + extra_deps)
native.py_library(name=name,
srcs=[":" + name + ".py"],
srcs_version="PY2AND3",
......
licenses(["restricted"])
package(default_visibility = ["//visibility:public"])
cc_library(
name = "headers",
hdrs = glob([
"numpy_include/**/*.h",
]),
data = ["//util/python:python_checked"],
includes = [
"numpy_include",
],
)
licenses(["restricted"])
package(default_visibility = ["//visibility:public"])
cc_library(
name = "python_headers",
hdrs = glob([
"python_include/**/*.h",
]),
data = [":python_checked"],
includes = ["python_include"],
)
genrule(
name = "python_check",
srcs = [
"python_config.sh",
],
outs = [
"python_checked",
],
cmd = "OUTPUTDIR=\"$(@D)/\"; ./util/python/python_config.sh --check && touch $$OUTPUTDIR/python_checked",
local = 1,
)
#!/bin/bash
set -e -o errexit
EXPECTED_PATHS="util/python/python_include util/python/python_lib third_party/py/numpy/numpy_include"
function main {
argument="$1"
shift
case $argument in
--check)
check_python
exit 0
;;
--setup)
setup_python "$1"
exit 0
;;
esac
}
function setup_python {
PYTHON_BIN_PATH="$1";
if [ -z "$PYTHON_BIN_PATH" ]; then
echo "PYTHON_BIN_PATH was not provided. Did you run configure?"
exit 1
fi
if [ ! -x "$PYTHON_BIN_PATH" ] || [ -d "$PYTHON_BIN_PATH" ]; then
echo "PYTHON_BIN_PATH is not executable. Is it the python binary?"
exit 1
fi
local python_include=$("${PYTHON_BIN_PATH}" -c 'from __future__ import print_function; from distutils import sysconfig; print(sysconfig.get_python_inc());')
if [ "$python_include" == "" ]; then
echo -e "\n\nERROR: Problem getting python include path. Is distutils installed?"
exit 1
fi
local python_lib=$("${PYTHON_BIN_PATH}" -c 'from __future__ import print_function; from distutils import sysconfig; print(sysconfig.get_python_lib());')
if [ "$python_lib" == "" ]; then
echo -e "\n\nERROR: Problem getting python lib path. Is distutils installed?"
exit 1
fi
local numpy_include=$("${PYTHON_BIN_PATH}" -c 'from __future__ import print_function; import numpy; print(numpy.get_include());')
if [ "$numpy_include" == "" ]; then
echo -e "\n\nERROR: Problem getting numpy include path. Is numpy installed?"
exit 1
fi
for x in $EXPECTED_PATHS; do
if [ -e "$x" ]; then
rm "$x"
fi
done
ln -s "${python_include}" util/python/python_include
ln -s "${python_lib}" util/python/python_lib
ln -s "${numpy_include}" third_party/py/numpy/numpy_include
}
function check_python {
for x in $EXPECTED_PATHS; do
if [ ! -e "$x" ]; then
echo -e "\n\nERROR: Cannot find '${x}'. Did you run configure?\n\n" 1>&2
exit 1
fi
if [ ! -L "${x}" ]; then
echo -e "\n\nERROR: '${x}' is not a symbolic link. Internal error.\n\n" 1>&2
exit 1
fi
true_path=$(readlink "${x}")
if [ ! -d "${true_path}" ]; then
echo -e "\n\nERROR: '${x}' does not refer to an existing directory: ${true_path}. Do you need to rerun configure?\n\n" 1>&2
exit 1
fi
done
}
main "$@"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册