Commit fa095c5d authored by Manjunath Kudlur

TensorFlow: upstream changes to git

Change 109195845
	Fix TensorFlow to build against Bazel 0.1.2rc2

	Two things are currently broken with TensorFlow and Bazel 0.1.2:
	  - Bazel now uses sandboxing by default on Linux and has fixed it for cc_* rules.
	    Undeclared headers are not mounted in the sandbox, which makes several cc_* rules
	    fail.
	  - Bazel now enforces strict header checking, and some targets were missing
	    headers even though the headers were mounted in the sandbox. This change
	    adds a "strict_headers" target that globs every header of the core
	    library and adds it to the `tf_cc_tests` targets.
Change 109162708
	Fix various website issues
	- Fix headline in os_setup.md
	- Fix #anchor links
Change 109162129
	Fix numbers in mnist tutorial, fixes #362
Change 109158967
	Fix typo in word2vec tutorial, fixes #347
Change 109151855
	Fix tile and its gradient for scalars on GPUs

	Eigen doesn't handle scalars on GPUs in all cases.  Fortunately, both
	tile and its gradient are the identity for scalars, so we can just copy
	the input to the output.

	Fixes https://github.com/tensorflow/tensorflow/issues/391.
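	A minimal sketch of the behavior this fix enables (not part of the original
	change description), assuming the TensorFlow Python API at this commit; the
	scalar case is exactly what the updated test runs on both CPU and GPU:

	```python
	import tensorflow as tf

	with tf.Session() as sess:
	    a = tf.constant(7, shape=[], dtype=tf.float32)
	    tiled = tf.tile(a, [])        # empty multiples for a 0-D input
	    print(sess.run(tiled))        # 7.0, same shape () as the input
	```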
Change 109140763
	Support int32 and int64 in tf.random_uniform

	This requires a new RandomUniformInt op on the C++ side since the op needs
	to know minval and maxval.

	Fixes https://github.com/tensorflow/tensorflow/issues/364.
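	A minimal usage sketch (illustrative only), assuming the TensorFlow Python
	API at this commit; for integer dtypes `maxval` has no default and must be
	passed explicitly:

	```python
	import tensorflow as tf

	ints = tf.random_uniform([2, 3], minval=0, maxval=10, dtype=tf.int32)
	floats = tf.random_uniform([2, 3])   # floats keep the default [0, 1) range
	with tf.Session() as sess:
	    print(sess.run(ints))            # int32 values in [0, 10)
	```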
Change 109140738
	Fix spacing in docs.
Change 109140030
	Fix content nav to not hide the bottom 100 or so px.
Change 109139967
	Add license files to TensorBoard files, fix mnist_with_summaries test
Change 109138333
	Fix typos in docstring
Change 109138098
	Fix some missing resources in the website.

	Fixes #366.
Change 109123771
	Make sparse_to_dense's default_value default to 0

	Nearly all uses of sparse_to_dense use 0 as the default.  The
	same goes for sparse_tensor_to_dense.
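	A minimal sketch of the new default (illustrative only, matching the
	testZeroDefault case added below), assuming the TensorFlow Python API at
	this commit:

	```python
	import tensorflow as tf

	with tf.Session() as sess:
	    dense = tf.sparse_to_dense(sparse_indices=2, output_shape=[4],
	                               sparse_values=7)   # default_value now 0
	    print(sess.run(dense))                        # [0 0 7 0]
	```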

Base CL: 109198336
Parent f586a5ee
......@@ -32,7 +32,7 @@ bind(
git_repository(
name = "re2",
remote = "https://github.com/google/re2.git",
tag = "2015-07-01",
commit = "791beff",
)
new_http_archive(
......
......@@ -101,7 +101,10 @@ tf_cuda_library(
"**/*main.cc",
],
),
hdrs = glob(["public/**/*.h"]),
hdrs = glob([
"public/**/*.h",
"util/device_name_utils.h",
]),
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
......@@ -345,6 +348,12 @@ cc_library(
alwayslink = 1,
)
# This is to workaround strict header checks
cc_library(
name = "strict_headers",
hdrs = glob(["**/*.h"]),
)
# Low level library tests
tf_cc_tests(
tests = glob(
......@@ -356,6 +365,7 @@ tf_cc_tests(
),
deps = [
":lib",
":strict_headers",
":test_main",
],
)
......@@ -404,6 +414,7 @@ tf_cc_tests(
":direct_session",
":kernels",
":lib",
":strict_headers",
":test_main",
":testlib",
"//tensorflow/cc:cc_ops",
......@@ -424,6 +435,7 @@ tf_cc_tests(
deps = [
":direct_session",
":kernels",
":strict_headers",
":test_main",
":testlib",
"//tensorflow/cc:cc_ops",
......
......@@ -46,7 +46,7 @@ template <typename Device, class Distribution>
struct FillPhiloxRandom {
typedef typename Distribution::ResultElementType T;
void operator()(OpKernelContext*, const Device&, random::PhiloxRandom gen,
T* data, int64 size) {
T* data, int64 size, Distribution dist) {
LOG(FATAL) << "Default FillPhiloxRandom should not be executed.";
}
};
......@@ -57,7 +57,8 @@ template <class Distribution>
struct FillPhiloxRandom<GPUDevice, Distribution> {
typedef typename Distribution::ResultElementType T;
void operator()(OpKernelContext* ctx, const GPUDevice&,
random::PhiloxRandom gen, T* data, int64 size);
random::PhiloxRandom gen, T* data, int64 size,
Distribution dist);
};
#endif
......@@ -72,8 +73,7 @@ template <class Distribution>
struct FillPhiloxRandomTask<Distribution, false> {
typedef typename Distribution::ResultElementType T;
static void Run(random::PhiloxRandom gen, T* data, int64 size,
int64 start_group, int64 limit_group) {
Distribution dist;
int64 start_group, int64 limit_group, Distribution dist) {
const int kGroupSize = Distribution::kResultElementCount;
gen.Skip(start_group);
......@@ -96,7 +96,7 @@ struct FillPhiloxRandomTask<Distribution, false> {
}
};
// Specialization for distribution that takes a varaiable number of samples for
// Specialization for distribution that takes a variable number of samples for
// each output. This will be slower due to the generality.
template <class Distribution>
struct FillPhiloxRandomTask<Distribution, true> {
......@@ -104,11 +104,10 @@ struct FillPhiloxRandomTask<Distribution, true> {
static const int64 kReservedSamplesPerOutput = 256;
static void Run(random::PhiloxRandom base_gen, T* data, int64 size,
int64 start_group, int64 limit_group) {
int64 start_group, int64 limit_group, Distribution dist) {
using random::PhiloxRandom;
using random::SingleSampleAdapter;
Distribution dist;
const int kGroupSize = Distribution::kResultElementCount;
static const int kGeneratorSkipPerOutputGroup =
......@@ -153,7 +152,8 @@ template <class Distribution>
struct FillPhiloxRandom<CPUDevice, Distribution> {
typedef typename Distribution::ResultElementType T;
void operator()(OpKernelContext* context, const CPUDevice&,
random::PhiloxRandom gen, T* data, int64 size) {
random::PhiloxRandom gen, T* data, int64 size,
Distribution dist) {
const int kGroupSize = Distribution::kResultElementCount;
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
......@@ -164,17 +164,49 @@ struct FillPhiloxRandom<CPUDevice, Distribution> {
// sub-linear. Too many threads causes a much worse overall performance.
int num_workers = 6;
Shard(num_workers, worker_threads.workers, total_group_count, kGroupSize,
[&gen, data, size](int64 start_group, int64 limit_group) {
[&gen, data, size, dist](int64 start_group, int64 limit_group) {
FillPhiloxRandomTask<
Distribution,
Distribution::kVariableSamplesPerOutput>::Run(gen, data, size,
start_group,
limit_group);
limit_group,
dist);
});
}
};
} // namespace functor
namespace {
static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape,
int index, Tensor** output) {
if (!TensorShapeUtils::IsLegacyVector(shape.shape())) {
return errors::InvalidArgument(
"shape must be a vector of {int32,int64}, got shape ",
shape.shape().ShortDebugString());
}
if (shape.dtype() == DataType::DT_INT32) {
auto vec = shape.flat<int32>();
TF_RETURN_IF_ERROR(ctx->allocate_output(
index, TensorShapeUtils::MakeShape(vec.data(), vec.size()), output));
} else if (shape.dtype() == DataType::DT_INT64) {
auto vec = shape.flat<int64>();
TF_RETURN_IF_ERROR(ctx->allocate_output(
index, TensorShapeUtils::MakeShape(vec.data(), vec.size()), output));
} else {
return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
}
return Status::OK();
}
// Reserve enough random samples in the generator for the given output count.
// Note that the 256 multiplier is repeated above; do not change it just here.
static random::PhiloxRandom ReserveRandomOutputs(GuardedPhiloxRandom& generator,
int64 output_count) {
int64 conservative_sample_count = output_count << 8;
return generator.ReserveSamples128(conservative_sample_count);
}
// For now, use the same interface as RandomOp, so we can choose either one
// at the run-time.
template <typename Device, class Distribution>
......@@ -186,41 +218,65 @@ class PhiloxRandomOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
const Tensor& input = ctx->input(0);
OP_REQUIRES(
ctx, TensorShapeUtils::IsLegacyVector(input.shape()),
errors::InvalidArgument("shape must be a vector of {int32,int64}."));
Tensor* output = nullptr;
if (input.dtype() == DataType::DT_INT32) {
auto vec = input.flat<int32>();
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShapeUtils::MakeShape(
vec.data(), vec.size()),
&output));
} else if (input.dtype() == DataType::DT_INT64) {
auto vec = input.flat<int64>();
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShapeUtils::MakeShape(
vec.data(), vec.size()),
&output));
} else {
OP_REQUIRES(ctx, false, errors::InvalidArgument(
"shape must be a vector of {int32,int64}."));
}
const Tensor& shape = ctx->input(0);
Tensor* output;
OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
auto output_flat = output->flat<T>();
functor::FillPhiloxRandom<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(),
ReserveRandomOutputs(output->flat<T>().size()),
output->flat<T>().data(), output->flat<T>().size());
ReserveRandomOutputs(generator_, output_flat.size()),
output_flat.data(), output_flat.size(), Distribution());
}
private:
GuardedPhiloxRandom generator_;
};
template <typename Device, class IntType>
class RandomUniformIntOp : public OpKernel {
public:
explicit RandomUniformIntOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, generator_.Init(ctx));
}
// Reserve enough random samples in the generator for the given output count.
random::PhiloxRandom ReserveRandomOutputs(int64 output_count) {
int64 conservative_sample_count = output_count << 8;
return generator_.ReserveSamples128(conservative_sample_count);
void Compute(OpKernelContext* ctx) override {
const Tensor& shape = ctx->input(0);
const Tensor& minval = ctx->input(1);
const Tensor& maxval = ctx->input(2);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
errors::InvalidArgument("minval must be 0-D, got shape ",
minval.shape().ShortDebugString()));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval.shape()),
errors::InvalidArgument("maxval must be 0-D, got shape ",
maxval.shape().ShortDebugString()));
// Verify that minval < maxval
IntType lo = minval.scalar<IntType>()();
IntType hi = maxval.scalar<IntType>()();
OP_REQUIRES(
ctx, lo < hi,
errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));
// Build distribution
typedef random::UniformDistribution<random::PhiloxRandom, IntType>
Distribution;
Distribution dist(lo, hi);
Tensor* output;
OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
auto output_flat = output->flat<IntType>();
functor::FillPhiloxRandom<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(),
ReserveRandomOutputs(generator_, output_flat.size()),
output_flat.data(), output_flat.size(), dist);
}
private:
GuardedPhiloxRandom generator_;
};
} // namespace
#define REGISTER(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("RandomUniform") \
......@@ -246,10 +302,22 @@ class PhiloxRandomOp : public OpKernel {
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)
#define REGISTER_INT(IntType) \
REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
.Device(DEVICE_CPU) \
.HostMemory("shape") \
.HostMemory("minval") \
.HostMemory("maxval") \
.TypeConstraint<IntType>("Tout"), \
RandomUniformIntOp<CPUDevice, IntType>);
REGISTER(float);
REGISTER(double);
REGISTER_INT(int32);
REGISTER_INT(int64);
#undef REGISTER
#undef REGISTER_INT
#if GOOGLE_CUDA
......@@ -281,10 +349,23 @@ REGISTER(double);
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)
#define REGISTER_INT(IntType) \
REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
.Device(DEVICE_GPU) \
.HostMemory("shape") \
.HostMemory("minval") \
.HostMemory("maxval") \
.TypeConstraint<int32>("T") \
.TypeConstraint<IntType>("Tout"), \
RandomUniformIntOp<GPUDevice, IntType>);
REGISTER(float);
REGISTER(double);
REGISTER_INT(int32);
REGISTER_INT(int64);
#undef REGISTER
#undef REGISTER_INT
#endif // GOOGLE_CUDA
......
......@@ -42,8 +42,8 @@ struct FillPhiloxRandomKernel;
template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, false> {
typedef typename Distribution::ResultElementType T;
PHILOX_DEVICE_FUNC void Run(random::PhiloxRandom gen, T* data, int64 size) {
Distribution dist;
PHILOX_DEVICE_FUNC void Run(random::PhiloxRandom gen, T* data, int64 size,
Distribution dist) {
const int kGroupSize = Distribution::kResultElementCount;
const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
......@@ -74,7 +74,7 @@ template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, true> {
typedef typename Distribution::ResultElementType T;
PHILOX_DEVICE_FUNC void Run(const random::PhiloxRandom& base_gen, T* data,
int64 size) {
int64 size, Distribution dist) {
using random::PhiloxRandom;
using random::SingleSampleAdapter;
......@@ -88,7 +88,6 @@ struct FillPhiloxRandomKernel<Distribution, true> {
const int32 total_thread_count = gridDim.x * blockDim.x;
int64 group_index = thread_id;
int64 offset = group_index * kGroupSize;
Distribution dist;
while (offset < size) {
// Since each output takes a variable number of samples, we need to
......@@ -118,10 +117,10 @@ template <class Distribution>
__global__ void __launch_bounds__(1024)
FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen,
typename Distribution::ResultElementType* data,
int64 size) {
int64 size, Distribution dist) {
FillPhiloxRandomKernel<Distribution,
Distribution::kVariableSamplesPerOutput>()
.Run(base_gen, data, size);
.Run(base_gen, data, size, dist);
}
// Partial specialization for GPU
......@@ -130,7 +129,7 @@ struct FillPhiloxRandom<GPUDevice, Distribution> {
typedef typename Distribution::ResultElementType T;
typedef GPUDevice Device;
void operator()(OpKernelContext*, const Device& d, random::PhiloxRandom gen,
T* data, int64 size) {
T* data, int64 size, Distribution dist) {
const int32 block_size = d.maxCudaThreadsPerBlock();
const int32 num_blocks =
(d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor()) /
......@@ -138,7 +137,7 @@ struct FillPhiloxRandom<GPUDevice, Distribution> {
FillPhiloxRandomKernelLaunch<
Distribution><<<num_blocks, block_size, 0, d.stream()>>>(gen, data,
size);
size, dist);
}
};
......@@ -149,6 +148,10 @@ template struct FillPhiloxRandom<
GPUDevice, random::UniformDistribution<random::PhiloxRandom, float> >;
template struct FillPhiloxRandom<
GPUDevice, random::UniformDistribution<random::PhiloxRandom, double> >;
template struct FillPhiloxRandom<
GPUDevice, random::UniformDistribution<random::PhiloxRandom, int32> >;
template struct FillPhiloxRandom<
GPUDevice, random::UniformDistribution<random::PhiloxRandom, int64> >;
template struct FillPhiloxRandom<
GPUDevice, random::NormalDistribution<random::PhiloxRandom, float> >;
template struct FillPhiloxRandom<
......
......@@ -52,11 +52,16 @@ class TileOp : public OpKernel {
errors::InvalidArgument(
"Expected multiples argument to be a vector of length ",
input.dims(), " but got length ", multiples.dim_size(0)));
const int input_dims = input.dims();
// Eigen doesn't support scalars on the GPU, so handle 0-D specially
if (input_dims == 0) {
context->set_output(0, input);
return;
}
const gtl::ArraySlice<int32> multiples_array(multiples.flat<int32>().data(),
input_dims);
TensorShape output_shape;
for (int i = 0; i < input_dims; ++i) {
OP_REQUIRES(
......@@ -75,7 +80,6 @@ class TileOp : public OpKernel {
}
#define HANDLE_TYPE(T) \
HANDLE_DIM(T, 0) \
HANDLE_DIM(T, 1) \
HANDLE_DIM(T, 2) \
HANDLE_DIM(T, 3) \
......@@ -142,16 +146,13 @@ inline void TileOp<Device>::HandleCase(
HandleCaseImpl<dtype, ndim>(context, multiples_array, result); \
}
#define HANDLE_CASE_DIM_POSITIVE(device, dtype) \
HANDLE_CASE(device, dtype, 1); \
HANDLE_CASE(device, dtype, 2); \
HANDLE_CASE(device, dtype, 3); \
HANDLE_CASE(device, dtype, 4); \
HANDLE_CASE(device, dtype, 5);
// 0-D handled above
#define HANDLE_CASE_DIM(device, dtype) \
HANDLE_CASE(device, dtype, 0); \
HANDLE_CASE_DIM_POSITIVE(device, dtype);
HANDLE_CASE(device, dtype, 1); \
HANDLE_CASE(device, dtype, 2); \
HANDLE_CASE(device, dtype, 3); \
HANDLE_CASE(device, dtype, 4); \
HANDLE_CASE(device, dtype, 5);
HANDLE_CASE_DIM(CPUDevice, DT_BOOL);
HANDLE_CASE_DIM(CPUDevice, DT_FLOAT);
......@@ -163,15 +164,13 @@ HANDLE_CASE_DIM(CPUDevice, DT_INT64);
HANDLE_CASE_DIM(CPUDevice, DT_STRING);
#if GOOGLE_CUDA
// Eigen on GPU does not handle 0-dimension data types yet.
HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_FLOAT);
HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_DOUBLE);
HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_INT16);
HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_INT32);
HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_INT64);
HANDLE_CASE_DIM(GPUDevice, DT_FLOAT);
HANDLE_CASE_DIM(GPUDevice, DT_DOUBLE);
HANDLE_CASE_DIM(GPUDevice, DT_INT16);
HANDLE_CASE_DIM(GPUDevice, DT_INT32);
HANDLE_CASE_DIM(GPUDevice, DT_INT64);
#endif // GOOGLE_CUDA
#undef HANDLE_CASE_DIM_POSITIVE
#undef HANDLE_CASE_DIM
#undef HANDLE_CASE
......@@ -194,9 +193,15 @@ class TileGradientOp : public OpKernel {
input.dims(), " but got length ", multiples.dim_size(0)));
const int input_dims = input.dims();
// Eigen doesn't support scalars on the GPU, so handle 0-D specially
if (input_dims == 0) {
context->set_output(0, input);
return;
}
const gtl::ArraySlice<int32> multiples_array(multiples.flat<int32>().data(),
input_dims);
TensorShape output_shape;
std::vector<int32> input_dim_size_vec;
for (int i = 0; i < input_dims; ++i) {
......@@ -223,7 +228,6 @@ class TileGradientOp : public OpKernel {
}
#define HANDLE_TYPE(T) \
HANDLE_DIM(T, 0) \
HANDLE_DIM(T, 1) \
HANDLE_DIM(T, 2) \
HANDLE_DIM(T, 3) \
......@@ -282,7 +286,7 @@ class TileGradientOp : public OpKernel {
// NOTE(keveman): Handling the most common case here.
// Adding more cases here would require more templating and code
// explosion. For instance, HANDLE_DIM(2) wouldn't make sense for NDIM=1.
HANDLE_DIM(NDIM > 0 ? 1 : 0);
HANDLE_DIM(1);
// Fall through to the unoptimized version.
#undef HANDLE_DIM
......@@ -362,16 +366,13 @@ inline void TileGradientOp<Device>::HandleCase(
HandleCaseImpl<dtype, ndim>(context, input_dims, multiples_array, result); \
}
#define HANDLE_CASE_DIM_POSITIVE(device, dtype) \
HANDLE_CASE(device, dtype, 1); \
HANDLE_CASE(device, dtype, 2); \
HANDLE_CASE(device, dtype, 3); \
HANDLE_CASE(device, dtype, 4); \
HANDLE_CASE(device, dtype, 5);
// 0-D handled specially above
#define HANDLE_CASE_DIM(device, dtype) \
HANDLE_CASE(device, dtype, 0); \
HANDLE_CASE_DIM_POSITIVE(device, dtype);
HANDLE_CASE(device, dtype, 1); \
HANDLE_CASE(device, dtype, 2); \
HANDLE_CASE(device, dtype, 3); \
HANDLE_CASE(device, dtype, 4); \
HANDLE_CASE(device, dtype, 5);
HANDLE_CASE_DIM(CPUDevice, DT_FLOAT);
HANDLE_CASE_DIM(CPUDevice, DT_DOUBLE);
......@@ -380,15 +381,13 @@ HANDLE_CASE_DIM(CPUDevice, DT_INT32);
HANDLE_CASE_DIM(CPUDevice, DT_INT64);
#if GOOGLE_CUDA
// Eigen on GPU does not handle 0-dimension data types yet.
HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_FLOAT);
HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_DOUBLE);
HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_INT16);
HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_INT32);
HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_INT64);
HANDLE_CASE_DIM(GPUDevice, DT_FLOAT);
HANDLE_CASE_DIM(GPUDevice, DT_DOUBLE);
HANDLE_CASE_DIM(GPUDevice, DT_INT16);
HANDLE_CASE_DIM(GPUDevice, DT_INT32);
HANDLE_CASE_DIM(GPUDevice, DT_INT64);
#endif // GOOGLE_CUDA
#undef HANDLE_CASE_DIM_POSITIVE
#undef HANDLE_CASE_DIM
#undef HANDLE_CASE
......
......@@ -88,6 +88,71 @@ class UniformDistribution<Generator, double> {
}
};
template <class Generator>
class UniformDistribution<Generator, int32> {
public:
// The number of elements that will be returned.
static const int kResultElementCount = Generator::kResultElementCount;
// Indicate that this distribution may take variable number of samples
// during the runtime.
static const bool kVariableSamplesPerOutput = false;
typedef Array<int32, kResultElementCount> ResultType;
typedef int32 ResultElementType;
// Must have lo < hi
UniformDistribution(int32 lo, int32 hi) : lo_(lo), range_(hi - lo) {}
PHILOX_DEVICE_INLINE
ResultType operator()(Generator* gen) {
typename Generator::ResultType sample = (*gen)();
ResultType result;
for (int i = 0; i < kResultElementCount; ++i) {
result[i] = lo_ + static_cast<int32>(sample[i] % range_);
}
return result;
}
private:
// Note that lo_ is intentionally signed while range_ is intentionally
// unsigned. This is because hi - lo can overflow signed integers if
// lo < 0 < hi, but always fits in unsigned.
int32 lo_;
uint32 range_;
};
template <class Generator>
class UniformDistribution<Generator, int64> {
public:
// The number of elements that will be returned.
static const int kResultElementCount = Generator::kResultElementCount / 2;
// Indicate that this distribution may take variable number of samples
// during the runtime.
static const bool kVariableSamplesPerOutput = false;
typedef Array<int64, kResultElementCount> ResultType;
typedef int64 ResultElementType;
// Must have lo < hi
UniformDistribution(int64 lo, int64 hi) : lo_(lo), range_(hi - lo) {}
PHILOX_DEVICE_INLINE
ResultType operator()(Generator* gen) {
typename Generator::ResultType sample = (*gen)();
ResultType result;
for (int i = 0; i < kResultElementCount; ++i) {
auto bits = sample[2 * i] | static_cast<uint64>(sample[2 * i + 1]) << 32;
result[i] = lo_ + static_cast<int64>(bits % range_);
}
return result;
}
private:
// Note that lo_ is intentionally signed while range_ is intentionally
// unsigned. This is because hi - lo can overflow signed integers if
// lo < 0 < hi, but always fits in unsigned.
int64 lo_;
uint64 range_;
};
// A class that adapts the underlying native multiple samples to return a single
// sample at a time.
template <class Generator>
......
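The int64 specialization above consumes two 32-bit Philox outputs per draw,
which is why its `kResultElementCount` is half the generator's. A hedged
Python sketch of the same packing and modulo reduction (the sample words and
bounds below are made-up values for illustration, not taken from the commit):

```python
lo, hi = -5, 1000                      # assumed example bounds, lo < hi
range_ = hi - lo                       # kept unsigned in the C++ class
sample = [0x12345678, 0x9ABCDEF0]      # two raw 32-bit generator outputs (made up)
bits = sample[0] | (sample[1] << 32)   # same packing as the C++ operator()
value = lo + (bits % range_)           # uniform in [lo, hi) up to modulo bias
print(value)
```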
......@@ -31,7 +31,7 @@ Input images can be of different types but output images are always float.
images: 4-D with shape `[batch, height, width, channels]`.
size:= A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
new size for the images.
resized_images: 4-D with shape
resized_images: 4-D with shape
`[batch, new_height, new_width, channels]`.
)doc");
......@@ -49,7 +49,7 @@ Input images can be of different types but output images are always float.
images: 4-D with shape `[batch, height, width, channels]`.
size:= A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
new size for the images.
resized_images: 4-D with shape
resized_images: 4-D with shape
`[batch, new_height, new_width, channels]`.
)doc");
......@@ -67,7 +67,7 @@ Input images can be of different types but output images are always float.
images: 4-D with shape `[batch, height, width, channels]`.
size:= A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
new size for the images.
resized_images: 4-D with shape
resized_images: 4-D with shape
`[batch, new_height, new_width, channels]`.
)doc");
......@@ -85,7 +85,7 @@ Input images can be of different types but output images are always float.
images: 4-D with shape `[batch, height, width, channels]`.
size:= A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
new size for the images.
resized_images: 4-D with shape
resized_images: 4-D with shape
`[batch, new_height, new_width, channels]`.
)doc");
......
......@@ -5107,6 +5107,68 @@ op {
description: "The generated values follow a uniform distribution in the range `[0, 1)`. The\nlower bound 0 is included in the range, while the upper bound 1 is excluded."
is_stateful: true
}
op {
name: "RandomUniformInt"
input_arg {
name: "shape"
description: "The shape of the output tensor."
type_attr: "T"
}
input_arg {
name: "minval"
description: "0-D. Inclusive lower bound on the generated integers."
type_attr: "Tout"
}
input_arg {
name: "maxval"
description: "0-D. Exclusive upper bound on the generated integers."
type_attr: "Tout"
}
output_arg {
name: "output"
description: "A tensor of the specified shape filled with uniform random integers."
type_attr: "Tout"
}
attr {
name: "seed"
type: "int"
default_value {
i: 0
}
description: "If either `seed` or `seed2` are set to be non-zero, the random number\ngenerator is seeded by the given seed. Otherwise, it is seeded by a\nrandom seed."
}
attr {
name: "seed2"
type: "int"
default_value {
i: 0
}
description: "A second seed to avoid seed collision."
}
attr {
name: "Tout"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
summary: "Outputs random integers from a uniform distribution."
description: "The generated values are uniform integers in the range `[minval, maxval)`.\nThe lower bound `minval` is included in the range, while the upper bound\n`maxval` is excluded.\n\nThe random integers are slightly biased unless `maxval - minval` is an exact\npower of two. The bias is small for values of `maxval - minval` significantly\nsmaller than the range of the output (either `2^32` or `2^64`)."
is_stateful: true
}
op {
name: "Range"
input_arg {
......@@ -5830,12 +5892,20 @@ op {
type: "int"
description: "The dimension which is partially reversed."
}
attr {
name: "batch_dim"
type: "int"
default_value {
i: 0
}
description: "The dimension along which reversal is performed."
}
attr {
name: "T"
type: "type"
}
summary: "Reverses variable length slices in dimension `seq_dim`."
description: "This op first slices `input` along the first dimension, and for each slice `i`,\nreverses the first `seq_lengths[i]` elements along the dimension `seq_dim`.\n\nThe elements of `seq_lengths` must obey `seq_lengths[i] < input.dims[seq_dim]`,\nand `seq_lengths` must be a vector of length `input.dims(0)`.\n\nThe output slice `i` along dimension 0 is then given by input slice `i`, with\nthe first `seq_lengths[i]` slices along dimension `seq_dim` reversed.\n\nFor example:\n\n```prettyprint\n# Given this:\nseq_dim = 1\ninput.dims = (4, ...)\nseq_lengths = [7, 2, 3, 5]\n\n# then slices of input are reversed on seq_dim, but only up to seq_lengths:\noutput[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...]\noutput[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...]\noutput[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...]\noutput[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...]\n\n# while entries past seq_lens are copied through:\noutput[0, 7:, :, ...] = input[0, 7:, :, ...]\noutput[1, 2:, :, ...] = input[1, 2:, :, ...]\noutput[2, 3:, :, ...] = input[2, 3:, :, ...]\noutput[3, 2:, :, ...] = input[3, 2:, :, ...]\n```"
summary: "Reverses variable length slices."
description: "This op first slices `input` along the dimension `batch_dim`, and for each\nslice `i`, reverses the first `seq_lengths[i]` elements along\nthe dimension `seq_dim`.\n\nThe elements of `seq_lengths` must obey `seq_lengths[i] < input.dims[seq_dim]`,\nand `seq_lengths` must be a vector of length `input.dims[batch_dim]`.\n\nThe output slice `i` along dimension `batch_dim` is then given by input\nslice `i`, with the first `seq_lengths[i]` slices along dimension\n`seq_dim` reversed.\n\nFor example:\n\n```prettyprint\n# Given this:\nbatch_dim = 0\nseq_dim = 1\ninput.dims = (4, 8, ...)\nseq_lengths = [7, 2, 3, 5]\n\n# then slices of input are reversed on seq_dim, but only up to seq_lengths:\noutput[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...]\noutput[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...]\noutput[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...]\noutput[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...]\n\n# while entries past seq_lens are copied through:\noutput[0, 7:, :, ...] = input[0, 7:, :, ...]\noutput[1, 2:, :, ...] = input[1, 2:, :, ...]\noutput[2, 3:, :, ...] = input[2, 3:, :, ...]\noutput[3, 2:, :, ...] = input[3, 2:, :, ...]\n```\n\nIn contrast, if:\n```prettyprint\n# Given this:\nbatch_dim = 2\nseq_dim = 0\ninput.dims = (8, ?, 4, ...)\nseq_lengths = [7, 2, 3, 5]\n\n# then slices of input are reversed on seq_dim, but only up to seq_lengths:\noutput[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...]\noutput[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...]\noutput[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...]\noutput[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...]\n\n# while entries past seq_lens are copied through:\noutput[7:, :, 0, :, ...] = input[7:, :, 0, :, ...]\noutput[2:, :, 1, :, ...] = input[2:, :, 1, :, ...]\noutput[3:, :, 2, :, ...] = input[3:, :, 2, :, ...]\noutput[2:, :, 3, :, ...] = input[2:, :, 3, :, ...]\n```"
}
op {
name: "Rsqrt"
......
......@@ -41,6 +41,38 @@ seed2: A second seed to avoid seed collision.
output: A tensor of the specified shape filled with uniform random values.
)doc");
REGISTER_OP("RandomUniformInt")
.Input("shape: T")
.Input("minval: Tout")
.Input("maxval: Tout")
.SetIsStateful()
.Output("output: Tout")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.Attr("Tout: {int32, int64}")
.Attr("T: {int32, int64}")
.Doc(R"doc(
Outputs random integers from a uniform distribution.
The generated values are uniform integers in the range `[minval, maxval)`.
The lower bound `minval` is included in the range, while the upper bound
`maxval` is excluded.
The random integers are slightly biased unless `maxval - minval` is an exact
power of two. The bias is small for values of `maxval - minval` significantly
smaller than the range of the output (either `2^32` or `2^64`).
shape: The shape of the output tensor.
minval: 0-D. Inclusive lower bound on the generated integers.
maxval: 0-D. Exclusive upper bound on the generated integers.
seed: If either `seed` or `seed2` are set to be non-zero, the random number
generator is seeded by the given seed. Otherwise, it is seeded by a
random seed.
seed2: A second seed to avoid seed collision.
output: A tensor of the specified shape filled with uniform random integers.
)doc");
REGISTER_OP("RandomStandardNormal")
.Input("shape: T")
.SetIsStateful()
......
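The op documentation above notes that the generated integers are slightly
biased unless `maxval - minval` is an exact power of two. A hedged, pure-Python
sketch of why that bias is negligible for small ranges (the range value is an
arbitrary example, not from the commit):

```python
range_ = 10                # example range that is not a power of two
total = 2 ** 32            # distinct 32-bit generator outputs
per_bucket = total // range_
extra = total % range_     # buckets hit one extra time by the modulo mapping
bias = ((per_bucket + 1) / total) - (per_bucket / total)
print(extra, bias)         # 6 buckets, each over-weighted by only 2**-32
```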
......@@ -690,25 +690,28 @@ This is the opposite of pack. The numpy equivalent is
- - -
### `tf.reverse_sequence(input, seq_lengths, seq_dim, name=None)` {#reverse_sequence}
### `tf.reverse_sequence(input, seq_lengths, seq_dim, batch_dim=None, name=None)` {#reverse_sequence}
Reverses variable length slices in dimension `seq_dim`.
Reverses variable length slices.
This op first slices `input` along the first dimension, and for each slice `i`,
reverses the first `seq_lengths[i]` elements along the dimension `seq_dim`.
This op first slices `input` along the dimension `batch_dim`, and for each
slice `i`, reverses the first `seq_lengths[i]` elements along
the dimension `seq_dim`.
The elements of `seq_lengths` must obey `seq_lengths[i] < input.dims[seq_dim]`,
and `seq_lengths` must be a vector of length `input.dims(0)`.
and `seq_lengths` must be a vector of length `input.dims[batch_dim]`.
The output slice `i` along dimension 0 is then given by input slice `i`, with
the first `seq_lengths[i]` slices along dimension `seq_dim` reversed.
The output slice `i` along dimension `batch_dim` is then given by input
slice `i`, with the first `seq_lengths[i]` slices along dimension
`seq_dim` reversed.
For example:
```prettyprint
# Given this:
batch_dim = 0
seq_dim = 1
input.dims = (4, ...)
input.dims = (4, 8, ...)
seq_lengths = [7, 2, 3, 5]
# then slices of input are reversed on seq_dim, but only up to seq_lengths:
......@@ -724,6 +727,27 @@ output[2, 3:, :, ...] = input[2, 3:, :, ...]
output[3, 2:, :, ...] = input[3, 2:, :, ...]
```
In contrast, if:
```prettyprint
# Given this:
batch_dim = 2
seq_dim = 0
input.dims = (8, ?, 4, ...)
seq_lengths = [7, 2, 3, 5]
# then slices of input are reversed on seq_dim, but only up to seq_lengths:
output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...]
output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...]
output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...]
output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...]
# while entries past seq_lens are copied through:
output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...]
output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...]
output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...]
output[2:, :, 3, :, ...] = input[2:, :, 3, :, ...]
```
##### Args:
......@@ -732,6 +756,8 @@ output[3, 2:, :, ...] = input[3, 2:, :, ...]
1-D with length `input.dims(0)` and
`max(seq_lengths) < input.dims(seq_dim)`
* <b>`seq_dim`</b>: An `int`. The dimension which is partially reversed.
* <b>`batch_dim`</b>: An optional `int`. Defaults to `0`.
The dimension along which reversal is performed.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
......
......@@ -157,13 +157,13 @@ Alias for field number 1
- - -
### `tf.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value, name=None)` {#sparse_to_dense}
### `tf.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0, name=None)` {#sparse_to_dense}
Converts a sparse representation into a dense tensor.
Builds an array `dense` with shape `output_shape` such that
```prettyprint
```python
# If sparse_indices is scalar
dense[i] = (i == sparse_indices ? sparse_values : default_value)
......@@ -174,34 +174,32 @@ dense[sparse_indices[i]] = sparse_values[i]
dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i]
```
All other values in `dense` are set to `default_value`. If `sparse_values` is a
scalar, all sparse indices are set to this single value.
All other values in `dense` are set to `default_value`. If `sparse_values`
is a scalar, all sparse indices are set to this single value.
##### Args:
* <b>`sparse_indices`</b>: A `Tensor`. Must be one of the following types: `int32`, `int64`.
0-D, 1-D, or 2-D. `sparse_indices[i]` contains the complete
index where `sparse_values[i]` will be placed.
* <b>`output_shape`</b>: A `Tensor`. Must have the same type as `sparse_indices`.
1-D. Shape of the dense output tensor.
* <b>`sparse_values`</b>: A `Tensor`.
1-D. Values corresponding to each row of `sparse_indices`,
or a scalar value to be used for all sparse indices.
* <b>`default_value`</b>: A `Tensor`. Must have the same type as `sparse_values`.
Scalar value to set for indices not specified in
`sparse_indices`.
* <b>`sparse_indices`</b>: A 0-D, 1-D, or 2-D `Tensor` of type `int32` or `int64`.
`sparse_indices[i]` contains the complete index where `sparse_values[i]`
will be placed.
* <b>`output_shape`</b>: A 1-D `Tensor` of the same type as `sparse_indices`. Shape
of the dense output tensor.
* <b>`sparse_values`</b>: A 0-D or 1-D `Tensor`. Values corresponding to each row of
`sparse_indices`, or a scalar value to be used for all sparse indices.
* <b>`default_value`</b>: A 0-D `Tensor` of the same type as `sparse_values`. Value
to set for indices not specified in `sparse_indices`. Defaults to zero.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
A `Tensor`. Has the same type as `sparse_values`.
Dense output tensor of shape `output_shape`.
Dense `Tensor` of shape `output_shape`. Has the same type as
`sparse_values`.
- - -
### `tf.sparse_tensor_to_dense(sp_input, default_value, name=None)` {#sparse_tensor_to_dense}
### `tf.sparse_tensor_to_dense(sp_input, default_value=0, name=None)` {#sparse_tensor_to_dense}
Converts a `SparseTensor` into a dense tensor.
......@@ -225,7 +223,7 @@ string tensor with values:
* <b>`sp_input`</b>: The input `SparseTensor`.
* <b>`default_value`</b>: Scalar value to set for indices not specified in
`sp_input`.
`sp_input`. Defaults to zero.
* <b>`name`</b>: A name prefix for the returned tensors (optional).
##### Returns:
......
......@@ -548,7 +548,7 @@ $ sudo easy_install -U six
[MacPorts](https://www.macports.org/) and re-install TensorFlow in that
copy of Python.
# Mac OS X: TypeError: `__init__()` got an unexpected keyword argument 'syntax'
### Mac OS X: TypeError: `__init__()` got an unexpected keyword argument 'syntax'
On Mac OS X, you may encounter the following when importing tensorflow.
......
......@@ -69,7 +69,7 @@ The code example below is a modification of the [simple MNIST tutorial]
added some summary ops, and run them every ten steps. If you run this and then
launch `tensorboard --logdir=/tmp/mnist_data`, you'll be able to visualize
statistics, such as how the weights or accuracy varied during training.
The code below is an exerpt; full source is [here](mnist_with_summaries.py).
The code below is an exerpt; full source is [here](../../tutorials/mnist/mnist_with_summaries.py).
```python
# Create the model
......
"""A very simple MNIST classifer, modified to display data in TensorBoard
See extensive documentation for the original model at
http://tensorflow.org/tutorials/mnist/beginners/index.md
See documentaion on the TensorBoard specific pieces at
http://tensorflow.org/how_tos/summaries_and_tensorboard/index.md
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Import data
import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
import tensorflow as tf
sess = tf.InteractiveSession()
# Create the model
x = tf.placeholder("float", [None, 784], name="x-input")
W = tf.Variable(tf.zeros([784,10]), name="weights")
b = tf.Variable(tf.zeros([10], name="bias"))
# use a name scope to organize nodes in the graph visualizer
with tf.name_scope("Wx_b") as scope:
y = tf.nn.softmax(tf.matmul(x,W) + b)
# Add summary ops to collect data
w_hist = tf.histogram_summary("weights", W)
b_hist = tf.histogram_summary("biases", b)
y_hist = tf.histogram_summary("y", y)
# Define loss and optimizer
y_ = tf.placeholder("float", [None,10], name="y-input")
# More name scopes will clean up the graph representation
with tf.name_scope("xent") as scope:
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
ce_summ = tf.scalar_summary("cross entropy", cross_entropy)
with tf.name_scope("train") as scope:
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
with tf.name_scope("test") as scope:
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
accuracy_summary = tf.scalar_summary("accuracy", accuracy)
# Merge all the summaries and write them out to /tmp/mnist_logs
merged = tf.merge_all_summaries()
writer = tf.train.SummaryWriter("/tmp/mnist_logs", sess.graph_def)
tf.initialize_all_variables().run()
# Train the model, and feed in test data and record summaries every 10 steps
for i in range(1000):
if i % 10 == 0: # Record summary data, and the accuracy
feed = {x: mnist.test.images, y_: mnist.test.labels}
result = sess.run([merged, accuracy], feed_dict=feed)
summary_str = result[0]
acc = result[1]
writer.add_summary(summary_str, i)
print("Accuracy at step %s: %s" % (i, acc))
else:
batch_xs, batch_ys = mnist.train.next_batch(100)
feed = {x: batch_xs, y_: batch_ys}
sess.run(train_step, feed_dict=feed)
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
......@@ -76,10 +76,11 @@ Isn't that bad? Well, the best computer vision methods do exploit this
structure, and we will in later tutorials. But the simple method we will be
using here, a softmax regression, won't.
The result is that `mnist.train.images` is a tensor (an n-dimensional array) with a
shape of `[60000, 784]`. The first dimension indexes the images and the second
dimension indexes the pixels in each image. Each entry in the tensor is the
pixel intensity between 0 and 1, for a particular pixel in a particular image.
The result is that `mnist.train.images` is a tensor (an n-dimensional array)
with a shape of `[55000, 784]`. The first dimension indexes the images and the
second dimension indexes the pixels in each image. Each entry in the tensor is
the pixel intensity between 0 and 1, for a particular pixel in a particular
image.
<div style="width:40%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="img/mnist-train-xs.png">
......@@ -89,11 +90,11 @@ The corresponding labels in MNIST are numbers between 0 and 9, describing
which digit a given image is of.
For the purposes of this tutorial, we're going to want our labels
as "one-hot vectors". A one-hot vector is a vector which is 0 in most
dimensions, and 1 in a single dimension. In this case, the \\(n\\)th digit will be
represented as a vector which is 1 in the \\(n\\)th dimensions. For example, 3
would be \\([0,0,0,1,0,0,0,0,0,0]\\).
dimensions, and 1 in a single dimension. In this case, the \\(n\\)th digit will
be represented as a vector which is 1 in the \\(n\\)th dimensions. For example,
3 would be \\([0,0,0,1,0,0,0,0,0,0]\\).
Consequently, `mnist.train.labels` is a
`[60000, 10]` array of floats.
`[55000, 10]` array of floats.
<div style="width:40%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="img/mnist-train-ys.png">
......
......@@ -91,12 +91,15 @@ def extract_labels(filename, one_hot=False):
class DataSet(object):
def __init__(self, images, labels, fake_data=False):
def __init__(self, images, labels, fake_data=False, one_hot=False):
"""Construct a DataSet. one_hot arg is used only if fake_data is true."""
if fake_data:
self._num_examples = 10000
self.one_hot = one_hot
else:
assert images.shape[0] == labels.shape[0], (
"images.shape: %s labels.shape: %s" % (images.shape,
'images.shape: %s labels.shape: %s' % (images.shape,
labels.shape))
self._num_examples = images.shape[0]
......@@ -132,8 +135,11 @@ class DataSet(object):
def next_batch(self, batch_size, fake_data=False):
"""Return the next `batch_size` examples from this data set."""
if fake_data:
fake_image = [1.0 for _ in xrange(784)]
fake_label = 0
fake_image = [1] * 784
if self.one_hot:
fake_label = [1] + [0] * 9
else:
fake_label = 0
return [fake_image for _ in xrange(batch_size)], [
fake_label for _ in xrange(batch_size)]
start = self._index_in_epoch
......@@ -160,9 +166,9 @@ def read_data_sets(train_dir, fake_data=False, one_hot=False):
data_sets = DataSets()
if fake_data:
data_sets.train = DataSet([], [], fake_data=True)
data_sets.validation = DataSet([], [], fake_data=True)
data_sets.test = DataSet([], [], fake_data=True)
data_sets.train = DataSet([], [], fake_data=True, one_hot=one_hot)
data_sets.validation = DataSet([], [], fake_data=True, one_hot=one_hot)
data_sets.test = DataSet([], [], fake_data=True, one_hot=one_hot)
return data_sets
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
......
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A very simple MNIST classifer, modified to display data in TensorBoard.
See extensive documentation for the original model at
http://tensorflow.org/tutorials/mnist/beginners/index.md
See documentaion on the TensorBoard specific pieces at
http://tensorflow.org/how_tos/summaries_and_tensorboard/index.md
If you modify this file, please update the exerpt in
how_tos/summaries_and_tensorboard/index.md.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.python.platform
from tensorflow.g3doc.tutorials.mnist import input_data
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
'for unit testing.')
flags.DEFINE_integer('max_steps', 1000, 'Number of steps to run trainer.')
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
def main(_):
# Import data
mnist = input_data.read_data_sets('/tmp/data/', one_hot=True,
fake_data=FLAGS.fake_data)
sess = tf.InteractiveSession()
# Create the model
x = tf.placeholder('float', [None, 784], name='x-input')
W = tf.Variable(tf.zeros([784, 10]), name='weights')
b = tf.Variable(tf.zeros([10], name='bias'))
# use a name scope to organize nodes in the graph visualizer
with tf.name_scope('Wx_b') as scope:
y = tf.nn.softmax(tf.matmul(x, W) + b)
# Add summary ops to collect data
w_hist = tf.histogram_summary('weights', W)
b_hist = tf.histogram_summary('biases', b)
y_hist = tf.histogram_summary('y', y)
# Define loss and optimizer
y_ = tf.placeholder('float', [None, 10], name='y-input')
# More name scopes will clean up the graph representation
with tf.name_scope('xent') as scope:
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
ce_summ = tf.scalar_summary('cross entropy', cross_entropy)
with tf.name_scope('train') as scope:
train_step = tf.train.GradientDescentOptimizer(
FLAGS.learning_rate).minimize(cross_entropy)
with tf.name_scope('test') as scope:
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
accuracy_summary = tf.scalar_summary('accuracy', accuracy)
# Merge all the summaries and write them out to /tmp/mnist_logs
merged = tf.merge_all_summaries()
writer = tf.train.SummaryWriter('/tmp/mnist_logs', sess.graph_def)
tf.initialize_all_variables().run()
# Train the model, and feed in test data and record summaries every 10 steps
for i in range(FLAGS.max_steps):
if i % 10 == 0: # Record summary data, and the accuracy
if FLAGS.fake_data:
batch_xs, batch_ys = mnist.train.next_batch(
100, fake_data=FLAGS.fake_data)
feed = {x: batch_xs, y_: batch_ys}
else:
feed = {x: mnist.test.images, y_: mnist.test.labels}
result = sess.run([merged, accuracy], feed_dict=feed)
summary_str = result[0]
acc = result[1]
writer.add_summary(summary_str, i)
print('Accuracy at step %s: %s' % (i, acc))
else:
batch_xs, batch_ys = mnist.train.next_batch(
100, fake_data=FLAGS.fake_data)
feed = {x: batch_xs, y_: batch_ys}
sess.run(train_step, feed_dict=feed)
if __name__ == '__main__':
tf.app.run()
......@@ -147,7 +147,7 @@ $$J_\text{NEG} = \log Q_\theta(D=1 |w_t, h) +
where \\(Q_\theta(D=1 | w, h)\\) is the binary logistic regression probability
under the model of seeing the word \\(w\\) in the context \\(h\\) in the dataset
\\(D\\), calculated in terms of the learned embedding vectors \\(\theta\\). In
practice we approximate the expectation by drawing \\(k\\) constrastive words
practice we approximate the expectation by drawing \\(k\\) contrastive words
from the noise distribution (i.e. we compute a
[Monte Carlo average](https://en.wikipedia.org/wiki/Monte_Carlo_integration)).
......
......@@ -487,6 +487,7 @@ tf_gen_op_wrapper_py(
name = "random_ops",
hidden = [
"RandomUniform",
"RandomUniformInt",
"RandomShuffle",
"RandomStandardNormal",
"TruncatedNormal",
......@@ -510,6 +511,7 @@ tf_gen_op_wrapper_py(
"SparseConcat",
"SparseSelectLastK",
"SparseReorder",
"SparseToDense",
],
require_shape_functions = True,
)
......
......@@ -168,20 +168,23 @@ class RandomUniformTest(tf.test.TestCase):
return func
def testRange(self):
for use_gpu in [False, True]:
for dt in tf.float32, tf.float64:
sampler = self._Sampler(1000, -2., 8., dt, use_gpu=use_gpu)
for use_gpu in False, True:
for dt in tf.float32, tf.float64, tf.int32, tf.int64:
sampler = self._Sampler(1000, minv=-2, maxv=8, dtype=dt,
use_gpu=use_gpu)
x = sampler()
self.assertTrue(-2 <= np.min(x))
self.assertTrue(np.max(x) <= 8)
self.assertTrue(np.max(x) < 8)
# Asserts that different trials (1000 samples per trial) is unlikely
# to see the same sequence of values. Will catch buggy
# implementations which uses the same random number seed.
def testDistinct(self):
for use_gpu in [False, True]:
for dt in tf.float32, tf.float64:
sampler = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu)
for use_gpu in False, True:
for dt in tf.float32, tf.float64, tf.int32, tf.int64:
maxv = 1.0 if dt.is_floating else 1 << 30
sampler = self._Sampler(1000, minv=0, maxv=maxv, dtype=dt,
use_gpu=use_gpu)
x = sampler()
y = sampler()
count = (x == y).sum()
......@@ -191,33 +194,57 @@ class RandomUniformTest(tf.test.TestCase):
print("count = ", count)
self.assertTrue(count < 10)
# Check that uniform ints actually follow a uniform distribution.
def testUniformInts(self):
minv = -2
maxv = 15
n = 100000
p = 1 / (maxv - minv)
# The counts should follow an (n, p) binomial distribution.
mean = p * n
std = np.sqrt(n * p * (1 - p))
for use_gpu in False, True:
for dt in tf.int32, tf.int64:
# Use a fixed seed here to make the test deterministic.
# Without the fixed seed, the 5 * std bound will (very rarely) fail.
sampler = self._Sampler(n // 10, minv=minv, maxv=maxv, dtype=dt,
use_gpu=use_gpu, seed=17)
x = sampler().ravel()
self.assertEqual(x.shape, (n,))
counts, _ = np.histogram(x, bins=maxv - minv)
self.assertEqual(counts.shape, (maxv - minv,))
self.assertEqual(counts.sum(), n)
error = np.abs(counts - mean)
self.assertLess(error.max(), 5 * std)
# Checks that the CPU and GPU implementation returns the same results,
# given the same random seed
def testCPUGPUMatch(self):
for dt in tf.float32, tf.float64:
for dt in tf.float32, tf.float64, tf.int32, tf.int64:
maxv = 1.0 if dt.is_floating else 17
results = {}
for use_gpu in [False, True]:
sampler = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu, seed=12345)
for use_gpu in False, True:
sampler = self._Sampler(1000, minv=0, maxv=maxv, dtype=dt,
use_gpu=use_gpu, seed=12345)
results[use_gpu] = sampler()
self.assertAllClose(results[False], results[True], rtol=1e-6, atol=1e-6)
self.assertAllEqual(results[False], results[True])
def testSeed(self):
for use_gpu in [False, True]:
for dt in tf.float32, tf.float64:
sx = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu, seed=345)
sy = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu, seed=345)
for use_gpu in False, True:
for dt in tf.float32, tf.float64, tf.int32, tf.int64:
sx = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=345)
sy = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=345)
self.assertAllEqual(sx(), sy())
def testNoCSE(self):
for use_gpu in [False, True]:
with self.test_session(use_gpu=use_gpu):
shape = [2, 3, 4]
rnd1 = tf.random_uniform(shape, 0.0, 1.0,
dtype=tf.float32)
rnd2 = tf.random_uniform(shape, 0.0, 1.0,
dtype=tf.float32)
diff = (rnd2 - rnd1).eval()
self.assertTrue(np.linalg.norm(diff) > 0.1)
shape = [2, 3, 4]
for use_gpu in False, True:
for dtype in tf.float32, tf.int32:
with self.test_session(use_gpu=use_gpu):
rnd1 = tf.random_uniform(shape, 0, 17, dtype=dtype)
rnd2 = tf.random_uniform(shape, 0, 17, dtype=dtype)
diff = (rnd2 - rnd1).eval()
self.assertTrue(np.linalg.norm(diff) > 0.1)
class RandomShapeTest(tf.test.TestCase):
......
......@@ -216,13 +216,14 @@ class ShapeOpsTest(tf.test.TestCase):
class TileTest(tf.test.TestCase):
def testScalar(self):
with self.test_session():
a = tf.constant(7, shape=[], dtype=tf.float32)
tiled = tf.tile(a, [])
result = tiled.eval()
self.assertEqual(result.shape, ())
self.assertEqual([], tiled.get_shape())
self.assertEqual(7, result)
for use_gpu in False, True:
with self.test_session(use_gpu=use_gpu):
a = tf.constant(7, shape=[], dtype=tf.float32)
tiled = tf.tile(a, [])
result = tiled.eval()
self.assertEqual(result.shape, ())
self.assertEqual([], tiled.get_shape())
self.assertEqual(7, result)
def testSimple(self):
with self.test_session():
......@@ -357,20 +358,23 @@ class TileTest(tf.test.TestCase):
self.assertAllClose(expected, result, 1e-3)
def _RunAndVerifyGradientResult(self, input_shape, multiples):
with self.test_session():
# Random values
inp = np.random.rand(*input_shape)
a = tf.constant([float(x) for x in inp.flatten()],
shape=input_shape, dtype=tf.float64)
tiled = tf.tile(a, multiples)
grad_shape = list(np.array(multiples) * np.array(inp.shape))
err = tf.test.compute_gradient_error(a,
list(input_shape),
tiled,
grad_shape,
x_init_value=inp)
print("tile(float) error = ", err)
self.assertLess(err, 1e-3)
for use_gpu in False, True:
with self.test_session(use_gpu=use_gpu):
# Random values
inp = np.asarray(np.random.rand(*input_shape))
a = tf.constant(inp, dtype=tf.float64)
tiled = tf.tile(a, multiples)
grad_shape = list(np.array(multiples) * np.array(inp.shape))
err = tf.test.compute_gradient_error(a,
list(input_shape),
tiled,
grad_shape,
x_init_value=inp)
print("tile(float) error = ", err)
self.assertLess(err, 1e-3)
def testGradientRandomScalar(self):
self._RunAndVerifyGradientResult([], [])
def testGradientRandom(self):
self._RunAndVerifyGradientResult([2, 2, 1, 1, 3], [1, 2, 1, 3, 1])
......
......@@ -71,6 +71,11 @@ class SparseToDenseTest(tf.test.TestCase):
[ 1, -1, -1, -1]]).astype(np.int32)
self.assertAllClose(np_ans, tf_ans)
def testZeroDefault(self):
with self.test_session():
x = tf.sparse_to_dense(2, [4], 7).eval()
self.assertAllEqual(x, [0, 0, 7, 0])
def test3d(self):
with self.test_session(use_gpu=False):
tf_ans = _SparseToDense([[1, 3, 0], [2, 0, 1]], [3, 4, 2], 1, -1).eval()
......
......@@ -122,7 +122,7 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=dtypes.float32,
ops.NoGradient("TruncatedNormal")
def random_uniform(shape, minval=0.0, maxval=1.0,
def random_uniform(shape, minval=0, maxval=None,
dtype=dtypes.float32, seed=None,
name=None):
"""Outputs random values from a uniform distribution.
......@@ -131,13 +131,22 @@ def random_uniform(shape, minval=0.0, maxval=1.0,
`[minval, maxval)`. The lower bound `minval` is included in the range, while
the upper bound `maxval` is excluded.
For floats, the default range is `[0, 1)`. For ints, at least `maxval` must
be specified explicitly.
In the integer case, the random integers are slightly biased unless
`maxval - minval` is an exact power of two. The bias is small for values of
`maxval - minval` significantly smaller than the range of the output (either
`2**32` or `2**64`).
Args:
shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
minval: A 0-D Tensor or Python value of type `dtype`. The lower bound on the
range of random values to generate.
range of random values to generate. Defaults to 0.
maxval: A 0-D Tensor or Python value of type `dtype`. The upper bound on
the range of random values to generate.
dtype: The type of the output.
the range of random values to generate. Defaults to 1 if `dtype` is
floating point.
dtype: The type of the output: `float32`, `float64`, `int32`, or `int64`.
seed: A Python integer. Used to create a random seed for the distribution.
See
[`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
......@@ -146,19 +155,28 @@ def random_uniform(shape, minval=0.0, maxval=1.0,
Returns:
A tensor of the specified shape filled with random uniform values.
Raises:
ValueError: If `dtype` is integral and `maxval` is not specified.
"""
dtype = dtypes.as_dtype(dtype)
if maxval is None:
if dtype.is_integer:
raise ValueError("Must specify maxval for integer dtype %r" % dtype)
maxval = 1
with ops.op_scope([shape, minval, maxval], name, "random_uniform") as name:
shape_tensor = _ShapeTensor(shape)
min_tensor = ops.convert_to_tensor(minval, dtype=dtype, name="min")
range_tensor = ops.convert_to_tensor(
maxval - minval, dtype=dtype, name="range")
shape = _ShapeTensor(shape)
minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
seed1, seed2 = random_seed.get_seed(seed)
rnd = gen_random_ops._random_uniform(shape_tensor, dtype,
seed=seed1,
seed2=seed2)
mul = rnd * range_tensor
value = math_ops.add(mul, min_tensor, name=name)
return value
if dtype.is_integer:
return gen_random_ops._random_uniform_int(shape, minval, maxval,
seed=seed1, seed2=seed2,
name=name)
else:
rnd = gen_random_ops._random_uniform(shape, dtype, seed=seed1,
seed2=seed2)
return math_ops.add(rnd * (maxval - minval), minval, name=name)
def random_shuffle(value, seed=None, name=None):
......@@ -197,6 +215,7 @@ ops.NoGradient("RandomUniform")
@ops.RegisterShape("TruncatedNormal")
@ops.RegisterShape("RandomStandardNormal")
@ops.RegisterShape("RandomUniform")
@ops.RegisterShape("RandomUniformInt")
def _RandomShape(op):
shape_val = tensor_util.ConstantValue(op.inputs[0])
if shape_val is not None:
......
......@@ -240,7 +240,48 @@ def _SparseToDenseShape(op):
return [tensor_shape.unknown_shape(ndims=input_shape_shape.num_elements())]
def sparse_tensor_to_dense(sp_input, default_value, name=None):
def sparse_to_dense(sparse_indices, output_shape, sparse_values,
default_value=0, name=None):
"""Converts a sparse representation into a dense tensor.
Builds an array `dense` with shape `output_shape` such that
```python
# If sparse_indices is scalar
dense[i] = (i == sparse_indices ? sparse_values : default_value)
# If sparse_indices is a vector, then for each i
dense[sparse_indices[i]] = sparse_values[i]
# If sparse_indices is an n by d matrix, then for each i in [0, n)
dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i]
```
All other values in `dense` are set to `default_value`. If `sparse_values`
is a scalar, all sparse indices are set to this single value.
Args:
sparse_indices: A 0-D, 1-D, or 2-D `Tensor` of type `int32` or `int64`.
`sparse_indices[i]` contains the complete index where `sparse_values[i]`
will be placed.
output_shape: A 1-D `Tensor` of the same type as `sparse_indices`. Shape
of the dense output tensor.
sparse_values: A 0-D or 1-D `Tensor`. Values corresponding to each row of
`sparse_indices`, or a scalar value to be used for all sparse indices.
default_value: A 0-D `Tensor` of the same type as `sparse_values`. Value
to set for indices not specified in `sparse_indices`. Defaults to zero.
name: A name for the operation (optional).
Returns:
Dense `Tensor` of shape `output_shape`. Has the same type as
`sparse_values`.
"""
return gen_sparse_ops._sparse_to_dense(sparse_indices, output_shape,
sparse_values, default_value,
name=name)
def sparse_tensor_to_dense(sp_input, default_value=0, name=None):
"""Converts a `SparseTensor` into a dense tensor.
This op is a convenience wrapper around `sparse_to_dense` for `SparseTensor`s.
......@@ -261,7 +302,7 @@ def sparse_tensor_to_dense(sp_input, default_value, name=None):
Args:
sp_input: The input `SparseTensor`.
default_value: Scalar value to set for indices not specified in
`sp_input`.
`sp_input`. Defaults to zero.
name: A name prefix for the returned tensors (optional).
Returns:
......@@ -275,12 +316,8 @@ def sparse_tensor_to_dense(sp_input, default_value, name=None):
if not isinstance(sp_input, ops.SparseTensor):
raise TypeError("Input must be a SparseTensor")
return gen_sparse_ops.sparse_to_dense(
sp_input.indices,
sp_input.shape,
sp_input.values,
default_value,
name=name)
return sparse_to_dense(sp_input.indices, sp_input.shape, sp_input.values,
default_value, name=name)
def sparse_to_indicator(sp_input, vocab_size, name=None):
......@@ -455,7 +492,7 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None):
all_row_indices = math_ops.cast(math_ops.range(num_rows), dtypes.int64)
empty_row_indices, _ = array_ops.list_diff(
all_row_indices, sp_input.indices[:, 0])
empty_row_indicator = gen_sparse_ops.sparse_to_dense(
empty_row_indicator = sparse_to_dense(
empty_row_indices, array_ops.expand_dims(sp_input.shape[0], -1), True,
False)
......
......@@ -269,7 +269,7 @@ def variable_scope(name_or_scope, reuse=None, initializer=None):
Sharing a variable by capturing a scope and setting reuse:
```python
with tf.variable_scope("foo") as scope.
with tf.variable_scope("foo") as scope:
v = tf.get_variable("v", [1])
scope.reuse_variables()
v1 = tf.get_variable("v", [1])
......@@ -280,7 +280,7 @@ def variable_scope(name_or_scope, reuse=None, initializer=None):
getting an existing variable in a non-reusing scope.
```python
with tf.variable_scope("foo") as scope.
with tf.variable_scope("foo"):
v = tf.get_variable("v", [1])
v1 = tf.get_variable("v", [1])
# Raises ValueError("... v already exists ...").
......
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/**
* Simple server for running TensorBoard during development.
*/
......
......@@ -4,6 +4,8 @@ cc_library(
name = "eigen3",
hdrs = glob([
"**/*.h",
"unsupported/Eigen/CXX11/*",
"Eigen/*",
]),
includes = [ "." ],
visibility = ["//visibility:public"],
......