提交 ddd4aaf5 编写于 作者: V Vijay Vasudevan

TensorFlow: upstream changes to git.

Change 109695551
	Update FAQ
Change 109694725
	Add a gradient for resize_bilinear op.
Change 109694505
	Don't mention variables module in docs

	variables.Variable should be tf.Variable.
Change 109658848
	Adding an option to create a new thread-pool for each session.
Change 109640570

	Take the snapshot of stream-executor.
	+ Expose an interface for scratch space allocation in the interface.

Change 109638559
	Let image_summary accept uint8 input

	This allows users to do their own normalization / scaling if the default
	(very weird) behavior of image_summary is undesired.

	This required a slight tweak to fake_input.cc to make polymorphically typed
	fake inputs infer if their type attr is not set but has a default.

	Unfortunately, adding a second valid type to image_summary *disables* automatic
	implicit conversion from np.float64 to tf.float32, so this change is slightly
	backwards incompatible.
Change 109636969
	Add serialization operations for SparseTensor.
Change 109636644
	Update generated Op docs.
Change 109634899
	TensorFlow: add a markdown file for producing release notes for our
	releases.  Seed with 0.5.0 with a boring but accurate description.
Change 109634502
	Let histogram_summary take any realnumbertype

	It used to take only floats, not it understands ints.
Change 109634434
	TensorFlow: update locations where we mention python 3 support, update
	them to current truth.
Change 109632108
	Move HSV <> RGB conversions, grayscale conversions, and adjust_* ops back to tensorflow
	- make GPU-capable version of RGBToHSV and HSVToRGB, allows only float input/output
	- change docs to reflect new size constraints
	- change HSV format to be [0,1] for all components
	- add automatic dtype conversion for all adjust_* and grayscale conversion ops
	- fix up docs
Change 109631077
	Improve optimizer exceptions

	1. grads_and_vars is now a tuple, so must be wrapped when passed to format.
	2. Use '%r' instead of '%s' for dtype formatting

Base CL: 109697989
上级 cd53f3c3
......@@ -31,9 +31,7 @@ installing from source, GPU-enabled support, etc., see
## Binary Installation
The TensorFlow Python API currently requires Python 2.7: we are
[working](https://github.com/tensorflow/tensorflow/issues/1) on adding support
for Python 3.
The TensorFlow Python API supports Python 2.7 and Python 3.3+.
The simplest way to install TensorFlow is using
[pip](https://pypi.python.org/pypi/pip) for both Linux and Mac.
......
# Release 0.5.0
Initial release of TensorFlow.
......@@ -48,8 +48,7 @@ namespace tensorflow {
namespace {
thread::ThreadPool* kernel_thread_pool_ = nullptr;
static bool InitModule(const SessionOptions& options) {
thread::ThreadPool* NewThreadPool(const SessionOptions& options) {
int32 inter_op_parallelism_threads =
options.config.inter_op_parallelism_threads();
if (inter_op_parallelism_threads == 0) {
......@@ -58,9 +57,13 @@ static bool InitModule(const SessionOptions& options) {
}
LOG(INFO) << "Direct session inter op parallelism threads: "
<< inter_op_parallelism_threads;
kernel_thread_pool_ = new thread::ThreadPool(options.env, "Compute",
inter_op_parallelism_threads);
return true;
return new thread::ThreadPool(options.env, "Compute",
inter_op_parallelism_threads);
}
thread::ThreadPool* GlobalThreadPool(const SessionOptions& options) {
static thread::ThreadPool* const thread_pool = NewThreadPool(options);
return thread_pool;
}
// TODO(vrv): Figure out how to unify the many different functions
......@@ -75,6 +78,8 @@ string GetRendezvousKey(const string& tensor_name,
frame_iter.frame_id, ":", frame_iter.iter_id);
}
} // namespace
// NOTE: On Android with a single device, there is never
// a risk of an OpKernel blocking indefinitely:
//
......@@ -90,7 +95,7 @@ string GetRendezvousKey(const string& tensor_name,
// This may change down the road when we add support for multiple
// devices that run concurrently, in which case we will need to
// revisit this decision.
void SchedClosure(std::function<void()> c) {
void DirectSession::SchedClosure(std::function<void()> c) {
// TODO(sanjay): Get rid of __ANDROID__ path
#ifdef __ANDROID__
// On Android, there is no implementation of ThreadPool that takes
......@@ -100,19 +105,20 @@ void SchedClosure(std::function<void()> c) {
// safe given the reasoning above.
c();
#else
kernel_thread_pool_->Schedule(c);
thread_pool_->Schedule(c);
#endif // __ANDROID__
}
} // namespace
DirectSession::DirectSession(const SessionOptions& options,
const DeviceMgr* device_mgr)
: options_(options),
device_mgr_(device_mgr),
cancellation_manager_(new CancellationManager()) {
static bool init = InitModule(options);
CHECK(init); // Avoids compiler warning that init is unused.
if (options_.config.use_per_session_threads()) {
thread_pool_ = NewThreadPool(options_);
} else {
thread_pool_ = GlobalThreadPool(options);
}
// NOTE(mrry): We do not need to use a unique string for the session
// handle, because DirectSession owns its devices. This may change
// in future versions.
......@@ -149,6 +155,10 @@ DirectSession::~DirectSession() {
delete it.second;
}
delete cancellation_manager_;
if (options_.config.use_per_session_threads()) {
delete thread_pool_;
}
}
Status DirectSession::Create(const GraphDef& graph) {
......@@ -230,7 +240,7 @@ Status DirectSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
Executor::Args args;
args.rendezvous = rendez;
args.cancellation_manager = cancellation_manager_;
args.runner = SchedClosure;
args.runner = [this](Executor::Args::Closure c) { SchedClosure(c); };
for (auto device_executor : executors_and_keys->device_executors) {
Executor* exec = device_executor.second;
......
......@@ -36,6 +36,7 @@ limitations under the License.
namespace tensorflow {
class Device;
class ThreadPool;
class DirectSession : public Session {
public:
......@@ -94,6 +95,12 @@ class DirectSession : public Session {
mutex graph_def_lock_;
GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
// The thread-pool to use for running ops.
thread::ThreadPool* thread_pool_ = nullptr;
// Schedules 'c' for execution.
void SchedClosure(std::function<void()> c);
mutex executor_lock_; // protects executors_
// Holds mappings from signature to the executors that process
// it. The reason for a level of indirection around mapped_type is
......
......@@ -162,6 +162,43 @@ TEST_F(DirectSessionMinusAXTest, TestConcurrency) {
delete tp;
}
TEST_F(DirectSessionMinusAXTest, TestPerSessionThreads) {
Initialize({1, 2, 3, 4});
SessionOptions options;
options.config.set_use_per_session_threads(true);
(*options.config.mutable_device_count())["CPU"] = 2;
std::unique_ptr<Session> session(NewSession(options));
ASSERT_TRUE(session != nullptr);
ASSERT_OK(session->Create(def_));
// Fill in the input and ask for the output
thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "test", 4);
// Run the graph 1000 times in 4 different threads concurrently.
std::vector<string> output_names = {y_ + ":0"};
auto fn = [&session, output_names]() {
for (int i = 0; i < 1000; ++i) {
std::vector<std::pair<string, Tensor>> inputs;
std::vector<Tensor> outputs;
// Run the graph
Status s = session->Run(inputs, output_names, {}, &outputs);
ASSERT_TRUE(s.ok());
ASSERT_EQ(1, outputs.size());
auto mat = outputs[0].matrix<float>();
EXPECT_FLOAT_EQ(3.0, mat(0, 0));
}
};
for (int i = 0; i < 4; ++i) {
tp->Schedule(fn);
}
// Wait for the functions to finish.
delete tp;
}
TEST_F(DirectSessionMinusAXTest, TwoCreateCallsFails) {
Initialize({1, 2, 3, 4});
std::unique_ptr<Session> session(CreateSession());
......
......@@ -41,9 +41,16 @@ message ConfigProto {
// 0 means the system picks an appropriate number.
//
// Note that the first Session created in the process sets the
// number of threads for all future sessions.
// number of threads for all future sessions unless use_per_session_threads is
// true.
int32 inter_op_parallelism_threads = 5;
// If true, use a new set of threads for this session rather than the global
// pool of threads. Only supported by direct sessions.
//
// If false, use the global threads created by the first session.
bool use_per_session_threads = 9;
// Assignment of Nodes to Devices is recomputed every placement_period
// steps until the system warms up (at which point the recomputation
// typically slows down automatically).
......
......@@ -144,6 +144,12 @@ Status FakeInputImpl::GetDataType(DataType* dt) const {
} else if (!arg_->type_attr().empty()) {
Status status = GetNodeAttr(*node_def_, arg_->type_attr(), dt);
if (!status.ok()) {
// Check if the type attr has a default
const OpDef::AttrDef* attr = FindAttr(arg_->type_attr(), *op_def_);
if (attr && attr->has_default_value()) {
*dt = attr->default_value().type();
return Status::OK();
}
return errors::InvalidArgument("Could not infer type for input '",
arg_->name(), "': ",
status.error_message());
......
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/array_ops.cc.
#define EIGEN_USE_THREADS
#include <algorithm>
#include <cmath>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/colorspace_op.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/public/status.h"
#include "tensorflow/core/public/tensor.h"
#include "tensorflow/core/public/tensor_shape.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
template <typename Device>
class RGBToHSVOp : public OpKernel {
public:
explicit RGBToHSVOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
OP_REQUIRES(context, input.dims() >= 1,
errors::InvalidArgument("input must be at least 1D",
input.shape().ShortDebugString()));
auto channels = input.dim_size(input.dims() - 1);
OP_REQUIRES(context, channels == 3,
errors::FailedPrecondition(
"input must have 3 channels but input only has ", channels,
" channels."));
// Create the output Tensor with the same dimensions as the input Tensor.
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
// Make a canonical image, maintaining the last (channel) dimension, while
// flattening all others do give the functor easy to work with data.
TTypes<float, 2>::ConstTensor input_data = input.flat_inner_dims<float>();
TTypes<float, 2>::Tensor output_data = output->flat_inner_dims<float>();
Tensor trange;
OP_REQUIRES_OK(
context, context->allocate_temp(DataTypeToEnum<float>::value,
TensorShape({input_data.dimension(0)}),
&trange));
TTypes<float, 1>::Tensor range = trange.tensor<float, 1>();
functor::RGBToHSV<Device>()(context->eigen_device<Device>(), input_data,
range, output_data);
}
};
template <typename Device>
class HSVToRGBOp : public OpKernel {
public:
explicit HSVToRGBOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
OP_REQUIRES(context, input.dims() >= 1,
errors::InvalidArgument("input must be at least 1D",
input.shape().ShortDebugString()));
auto channels = input.dim_size(input.dims() - 1);
OP_REQUIRES(context, channels == 3,
errors::FailedPrecondition(
"input must have 3 channels but input only has ", channels,
" channels."));
// Create the output Tensor with the same dimensions as the input Tensor.
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
TTypes<float, 2>::ConstTensor input_data = input.flat_inner_dims<float>();
TTypes<float, 2>::Tensor output_data = output->flat_inner_dims<float>();
functor::HSVToRGB<Device>()(context->eigen_device<Device>(), input_data,
output_data);
}
};
REGISTER_KERNEL_BUILDER(Name("RGBToHSV").Device(DEVICE_CPU),
RGBToHSVOp<CPUDevice>);
template class RGBToHSVOp<CPUDevice>;
REGISTER_KERNEL_BUILDER(Name("HSVToRGB").Device(DEVICE_CPU),
HSVToRGBOp<CPUDevice>);
template class HSVToRGBOp<CPUDevice>;
#if GOOGLE_CUDA
// Forward declarations of the function specializations for GPU (to prevent
// building the GPU versions here, they will be built compiling _gpu.cu.cc).
namespace functor {
template <>
void RGBToHSV<GPUDevice>::operator()(const GPUDevice& d,
TTypes<float, 2>::ConstTensor input_data,
TTypes<float, 1>::Tensor range,
TTypes<float, 2>::Tensor output_data);
extern template struct RGBToHSV<GPUDevice>;
template <>
void HSVToRGB<GPUDevice>::operator()(const GPUDevice& d,
TTypes<float, 2>::ConstTensor input_data,
TTypes<float, 2>::Tensor output_data);
extern template struct HSVToRGB<GPUDevice>;
} // namespace functor
REGISTER_KERNEL_BUILDER(Name("RGBToHSV").Device(DEVICE_GPU),
RGBToHSVOp<GPUDevice>);
REGISTER_KERNEL_BUILDER(Name("HSVToRGB").Device(DEVICE_GPU),
HSVToRGBOp<GPUDevice>);
#endif
} // namespace tensorflow
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_KERNELS_COLORSPACE_OP_H_
#define TENSORFLOW_KERNELS_COLORSPACE_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/public/tensor_shape.h"
namespace tensorflow {
namespace functor {
template <typename Device>
struct RGBToHSV {
void operator()(const Device &d, TTypes<float, 2>::ConstTensor input_data,
TTypes<float, 1>::Tensor range,
TTypes<float, 2>::Tensor output_data) {
auto H = output_data.chip<1>(0);
auto S = output_data.chip<1>(1);
auto V = output_data.chip<1>(2);
auto R = input_data.chip<1>(0);
auto G = input_data.chip<1>(1);
auto B = input_data.chip<1>(2);
#if !defined(EIGEN_HAS_INDEX_LIST)
Eigen::array<int, 1> channel_axis{{1}};
#else
Eigen::IndexList<Eigen::type2index<1> > channel_axis;
#endif
V.device(d) = input_data.maximum(channel_axis);
range.device(d) = V - input_data.minimum(channel_axis);
S.device(d) = (V > 0.f).select(range / V, V.constant(0.f));
auto norm = range.inverse() * (1.f / 6.f);
// TODO(wicke): all these assignments are only necessary because a combined
// expression is larger than kernel parameter space. A custom kernel is
// probably in order.
H.device(d) = (R == V).select(norm * (G - B),
(G == V).select(norm * (B - R) + 2.f / 6.f,
norm * (R - G) + 4.f / 6.f));
H.device(d) = (range > 0.f).select(H, H.constant(0.f));
H.device(d) = (H < 0.f).select(H + 1.f, H);
}
};
template <typename Device>
struct HSVToRGB {
void operator()(const Device &d, TTypes<float, 2>::ConstTensor input_data,
TTypes<float, 2>::Tensor output_data) {
auto H = input_data.chip<1>(0);
auto S = input_data.chip<1>(1);
auto V = input_data.chip<1>(2);
// TODO(wicke): compute only the fractional part of H for robustness
auto dh = H * 6.f;
auto dr = ((dh - 3.f).abs() - 1.f).cwiseMax(0.f).cwiseMin(1.f);
auto dg = (-(dh - 2.f).abs() + 2.f).cwiseMax(0.f).cwiseMin(1.f);
auto db = (-(dh - 4.f).abs() + 2.f).cwiseMax(0.f).cwiseMin(1.f);
auto one_s = -S + 1.f;
auto R = output_data.chip<1>(0);
auto G = output_data.chip<1>(1);
auto B = output_data.chip<1>(2);
R.device(d) = (one_s + S * dr) * V;
G.device(d) = (one_s + S * dg) * V;
B.device(d) = (one_s + S * db) * V;
}
};
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_COLORSPACE_OP_H_
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/colorspace_op.h"
#include "tensorflow/core/framework/register_types.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
template class functor::RGBToHSV<GPUDevice>;
template class functor::HSVToRGB<GPUDevice>;
}
#endif // GOOGLE_CUDA
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/allocator.h"
#include <gtest/gtest.h>
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/public/tensor.h"
namespace tensorflow {
class RGBToHSVOpTest : public OpsTestBase {
protected:
RGBToHSVOpTest() {
RequireDefaultOps();
EXPECT_OK(NodeDefBuilder("rgb_to_hsv_op", "RGBToHSV")
.Input(FakeInput(DT_FLOAT))
.Finalize(node_def()));
EXPECT_OK(InitOp());
}
};
TEST_F(RGBToHSVOpTest, CheckBlack) {
// Black pixel should map to hsv = [0,0,0]
AddInputFromArray<float>(TensorShape({3}), {0, 0, 0});
ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({3}));
test::FillValues<float>(&expected, {0.0, 0.0, 0.0});
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
TEST_F(RGBToHSVOpTest, CheckGray) {
// Gray pixel should have hue = saturation = 0.0, value = r/255
AddInputFromArray<float>(TensorShape({3}), {.5, .5, .5});
ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({3}));
test::FillValues<float>(&expected, {0.0, 0.0, .5});
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
TEST_F(RGBToHSVOpTest, CheckWhite) {
// Gray pixel should have hue = saturation = 0.0, value = 1.0
AddInputFromArray<float>(TensorShape({3}), {1, 1, 1});
ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({3}));
test::FillValues<float>(&expected, {0.0, 0.0, 1.0});
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
TEST_F(RGBToHSVOpTest, CheckRedMax) {
// Test case where red channel dominates
AddInputFromArray<float>(TensorShape({3}), {.8, .4, .2});
ASSERT_OK(RunOpKernel());
float expected_h = 1. / 6. * .2 / .6;
float expected_s = .6 / .8;
float expected_v = .8 / 1.;
Tensor expected(allocator(), DT_FLOAT, TensorShape({3}));
test::FillValues<float>(&expected, {expected_h, expected_s, expected_v});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-6);
}
TEST_F(RGBToHSVOpTest, CheckGreenMax) {
// Test case where green channel dominates
AddInputFromArray<float>(TensorShape({3}), {.2, .8, .4});
ASSERT_OK(RunOpKernel());
float expected_h = 1. / 6. * (2.0 + (.2 / .6));
float expected_s = .6 / .8;
float expected_v = .8 / 1.;
Tensor expected(allocator(), DT_FLOAT, TensorShape({3}));
test::FillValues<float>(&expected, {expected_h, expected_s, expected_v});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-6);
}
TEST_F(RGBToHSVOpTest, CheckBlueMax) {
// Test case where blue channel dominates
AddInputFromArray<float>(TensorShape({3}), {.4, .2, .8});
ASSERT_OK(RunOpKernel());
float expected_h = 1. / 6. * (4.0 + (.2 / .6));
float expected_s = .6 / .8;
float expected_v = .8 / 1.;
Tensor expected(allocator(), DT_FLOAT, TensorShape({3}));
test::FillValues<float>(&expected, {expected_h, expected_s, expected_v});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-6);
}
TEST_F(RGBToHSVOpTest, CheckNegativeDifference) {
AddInputFromArray<float>(TensorShape({3}), {0, .1, .2});
ASSERT_OK(RunOpKernel());
float expected_h = 1. / 6. * (4.0 + (-.1 / .2));
float expected_s = .2 / .2;
float expected_v = .2 / 1.;
Tensor expected(allocator(), DT_FLOAT, TensorShape({3}));
test::FillValues<float>(&expected, {expected_h, expected_s, expected_v});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-6);
}
class HSVToRGBOpTest : public OpsTestBase {
protected:
HSVToRGBOpTest() {
RequireDefaultOps();
EXPECT_OK(NodeDefBuilder("hsv_to_rgb_op", "HSVToRGB")
.Input(FakeInput(DT_FLOAT))
.Finalize(node_def()));
EXPECT_OK(InitOp());
}
};
TEST_F(HSVToRGBOpTest, CheckBlack) {
// Black pixel should map to rgb = [0,0,0]
AddInputFromArray<float>(TensorShape({3}), {0.0, 0.0, 0.0});
ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({3}));
test::FillValues<float>(&expected, {0, 0, 0});
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
TEST_F(HSVToRGBOpTest, CheckGray) {
// Gray pixel should have hue = saturation = 0.0, value = r/255
AddInputFromArray<float>(TensorShape({3}), {0.0, 0.0, .5});
ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({3}));
test::FillValues<float>(&expected, {.5, .5, .5});
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
TEST_F(HSVToRGBOpTest, CheckWhite) {
// Gray pixel should have hue = saturation = 0.0, value = 1.0
AddInputFromArray<float>(TensorShape({3}), {0.0, 0.0, 1.0});
ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({3}));
test::FillValues<float>(&expected, {1, 1, 1});
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
TEST_F(HSVToRGBOpTest, CheckRedMax) {
// Test case where red channel dominates
float expected_h = 1. / 6. * .2 / .6;
float expected_s = .6 / .8;
float expected_v = .8 / 1.;
AddInputFromArray<float>(TensorShape({3}),
{expected_h, expected_s, expected_v});
ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({3}));
test::FillValues<float>(&expected, {.8, .4, .2});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-6);
}
TEST_F(HSVToRGBOpTest, CheckGreenMax) {
// Test case where green channel dominates
float expected_h = 1. / 6. * (2.0 + (.2 / .6));
float expected_s = .6 / .8;
float expected_v = .8 / 1.;
AddInputFromArray<float>(TensorShape({3}),
{expected_h, expected_s, expected_v});
ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({3}));
test::FillValues<float>(&expected, {.2, .8, .4});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-6);
}
TEST_F(HSVToRGBOpTest, CheckBlueMax) {
// Test case where blue channel dominates
float expected_h = 1. / 6. * (4.0 + (.2 / .6));
float expected_s = .6 / .8;
float expected_v = .8 / 1.0;
AddInputFromArray<float>(TensorShape({3}),
{expected_h, expected_s, expected_v});
ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({3}));
test::FillValues<float>(&expected, {.4, .2, .8});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-6);
}
} // namespace tensorflow
......@@ -107,6 +107,102 @@ class ResizeBilinearOp : public OpKernel {
}
};
template <typename Device, typename T>
class ResizeBilinearOpGrad : public OpKernel {
public:
explicit ResizeBilinearOpGrad(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Validate input.
// First argument is gradient with respect to resized image.
const Tensor& input = context->input(0);
OP_REQUIRES(context, input.dims() == 4,
errors::InvalidArgument("input_grad must be 4-dimensional",
input.shape().ShortDebugString()));
// ResizeBilinear always produces float images, so the input gradient is
// always a float.
OP_REQUIRES(context, input.dtype() == DT_FLOAT,
errors::InvalidArgument("input_grad must be of type float",
input.dtype()));
// The second argument is the original input to resize_bilinear.
const Tensor& original_image = context->input(1);
OP_REQUIRES(
context, original_image.dims() == 4,
errors::InvalidArgument("original_image must be 4-dimensional",
original_image.shape().ShortDebugString()));
// Allocate output and initialize to zeros.
const int64 batch_size = input.dim_size(0);
const int64 channels = input.dim_size(3);
const int64 resized_height = input.dim_size(1);
const int64 resized_width = input.dim_size(2);
const int64 original_height = original_image.dim_size(1);
const int64 original_width = original_image.dim_size(2);
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(
0, TensorShape({batch_size, original_height,
original_width, channels}),
&output));
typename TTypes<float, 4>::ConstTensor input_grad =
input.tensor<float, 4>();
typename TTypes<T, 4>::Tensor output_grad = output->tensor<T, 4>();
for (int c = 0; c < channels; ++c) {
for (int y = 0; y < original_height; ++y) {
for (int x = 0; x < original_width; ++x) {
for (int b = 0; b < batch_size; ++b) {
output_grad(b, y, x, c) = 0;
}
}
}
}
const float height_scale =
original_height / static_cast<float>(resized_height);
const float width_scale =
original_width / static_cast<float>(resized_width);
// Each resized pixel was computed as a weighted average of four input
// pixels. Here we find the pixels that contributed to each output pixel
// and add the corresponding coefficient to the gradient.
// resized(b, y, x, c) = top_left * (1 - y) * (1 - x)
// + top_right * (1 - y) * x
// + bottom_left * y * (1 - x)
// + bottom_right * y * x
for (int b = 0; b < batch_size; ++b) {
for (int y = 0; y < resized_height; ++y) {
const float in_y = y * height_scale;
const int top_y_index = static_cast<int>(floorf(in_y));
const int bottom_y_index =
std::min(static_cast<int64>(ceilf(in_y)), (original_height - 1));
const float y_lerp = in_y - top_y_index;
const float inverse_y_lerp = (1.0f - y_lerp);
for (int x = 0; x < resized_width; ++x) {
const float in_x = x * width_scale;
const int left_x_index = static_cast<int>(floorf(in_x));
const int right_x_index =
std::min(static_cast<int64>(ceilf(in_x)), (original_width - 1));
const float x_lerp = in_x - left_x_index;
const float inverse_x_lerp = (1.0f - x_lerp);
for (int c = 0; c < channels; ++c) {
output_grad(b, top_y_index, left_x_index, c) +=
input_grad(b, y, x, c) * inverse_y_lerp * inverse_x_lerp;
output_grad(b, top_y_index, right_x_index, c) +=
input_grad(b, y, x, c) * inverse_y_lerp * x_lerp;
output_grad(b, bottom_y_index, left_x_index, c) +=
input_grad(b, y, x, c) * y_lerp * inverse_x_lerp;
output_grad(b, bottom_y_index, right_x_index, c) +=
input_grad(b, y, x, c) * y_lerp * x_lerp;
}
}
}
}
}
};
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("ResizeBilinear") \
.Device(DEVICE_CPU) \
......@@ -121,4 +217,12 @@ REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL
REGISTER_KERNEL_BUILDER(Name("ResizeBilinearGrad")
.Device(DEVICE_CPU)
.TypeConstraint<float>("T"),
ResizeBilinearOpGrad<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("ResizeBilinearGrad")
.Device(DEVICE_CPU)
.TypeConstraint<double>("T"),
ResizeBilinearOpGrad<CPUDevice, double>);
} // namespace tensorflow
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include <algorithm>
#include <unordered_map>
#include <utility>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/public/tensor.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
namespace tensorflow {
using sparse::SparseTensor;
class SerializeSparseOp : public OpKernel {
public:
explicit SerializeSparseOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor* input_indices;
const Tensor* input_values;
const Tensor* input_shape;
OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
errors::InvalidArgument(
"Input indices should be a matrix but received shape ",
input_indices->shape().DebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
errors::InvalidArgument(
"Input values should be a vector but received shape ",
input_values->shape().DebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
errors::InvalidArgument(
"Input shape should be a vector but received shape ",
input_shape->shape().DebugString()));
TensorProto proto_indices;
TensorProto proto_values;
TensorProto proto_shape;
input_indices->AsProtoTensorContent(&proto_indices);
input_values->AsProtoTensorContent(&proto_values);
input_shape->AsProtoTensorContent(&proto_shape);
Tensor serialized_sparse(DT_STRING, TensorShape({3}));
auto serialized_sparse_t = serialized_sparse.vec<string>();
serialized_sparse_t(0) = proto_indices.SerializeAsString();
serialized_sparse_t(1) = proto_values.SerializeAsString();
serialized_sparse_t(2) = proto_shape.SerializeAsString();
context->set_output(0, serialized_sparse);
}
};
REGISTER_KERNEL_BUILDER(Name("SerializeSparse").Device(DEVICE_CPU),
SerializeSparseOp);
template <typename T>
class SerializeManySparseOp : public OpKernel {
public:
explicit SerializeManySparseOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor* input_indices;
const Tensor* input_values;
const Tensor* input_shape;
OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
errors::InvalidArgument(
"Input indices should be a matrix but received shape ",
input_indices->shape().DebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
errors::InvalidArgument(
"Input values should be a vector but received shape ",
input_values->shape().DebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
errors::InvalidArgument(
"Input shape should be a vector but received shape ",
input_shape->shape().DebugString()));
int rank = input_shape->NumElements();
OP_REQUIRES(
context, rank > 1,
errors::InvalidArgument(
"Rank of input SparseTensor should be > 1, but saw rank: ", rank));
TensorShape tensor_input_shape(input_shape->vec<int64>());
gtl::InlinedVector<int64, 8> std_order(rank);
std::iota(std_order.begin(), std_order.end(), 0);
SparseTensor input_st(*input_indices, *input_values, tensor_input_shape,
std_order);
auto input_shape_t = input_shape->vec<int64>();
const int64 N = input_shape_t(0);
Tensor serialized_sparse(DT_STRING, TensorShape({N, 3}));
auto serialized_sparse_t = serialized_sparse.matrix<string>();
OP_REQUIRES(context, input_st.IndicesValid(),
errors::InvalidArgument("Input SparseTensor fails check for "
"lexicographic ordering of indices. "
"Cannot split."));
// We can generate the output shape proto string now, for all
// minibatch entries.
Tensor output_shape(DT_INT64, {rank - 1});
auto output_shape_t = output_shape.vec<int64>();
for (int d = 1; d < rank; d++) output_shape_t(d - 1) = input_shape_t(d);
TensorProto proto_shape;
output_shape.AsProtoTensorContent(&proto_shape);
const string proto_shape_string = proto_shape.SerializeAsString();
Tensor output_blank_indices(DT_INT64, {0, rank - 1});
Tensor output_blank_values(DataTypeToEnum<T>::value, {0});
TensorProto proto_blank_indices;
TensorProto proto_blank_values;
output_blank_indices.AsProtoTensorContent(&proto_blank_indices);
output_blank_values.AsProtoTensorContent(&proto_blank_values);
const string proto_blank_indices_string =
proto_blank_indices.SerializeAsString();
const string proto_blank_values_string =
proto_blank_values.SerializeAsString();
// Initialize output with empty values and the proper shapes.
serialized_sparse_t.chip<1>(0).setConstant(proto_blank_indices_string);
serialized_sparse_t.chip<1>(1).setConstant(proto_blank_values_string);
serialized_sparse_t.chip<1>(2).setConstant(proto_shape_string);
// Get groups by minibatch dimension
sparse::GroupIterable minibatch = input_st.group({0});
for (const auto& subset : minibatch) {
const int64 b = subset.group()[0];
OP_REQUIRES(
context, b > -1 && b < N,
errors::InvalidArgument(
"Received unexpected column 0 value in input SparseTensor: ", b,
" < 0 or >= N (= ", N, ")"));
const auto indices = subset.indices();
const auto values = subset.values<T>();
const int64 num_entries = values.size();
Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1});
Tensor output_values = Tensor(DataTypeToEnum<T>::value, {num_entries});
auto output_indices_t = output_indices.matrix<int64>();
auto output_values_t = output_values.vec<T>();
for (int i = 0; i < num_entries; ++i) {
for (int d = 1; d < rank; ++d) {
output_indices_t(i, d - 1) = indices(i, d);
}
output_values_t(i) = values(i);
}
TensorProto proto_indices;
TensorProto proto_values;
output_indices.AsProtoTensorContent(&proto_indices);
output_values.AsProtoTensorContent(&proto_values);
serialized_sparse_t(b, 0) = proto_indices.SerializeAsString();
serialized_sparse_t(b, 1) = proto_values.SerializeAsString();
}
context->set_output(0, serialized_sparse);
}
};
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("SerializeManySparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T"), \
SerializeManySparseOp<type>)
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
template <typename T>
class DeserializeManySparseOp : public OpKernel {
public:
explicit DeserializeManySparseOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor& serialized_sparse = context->input(0);
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(serialized_sparse.shape()),
errors::InvalidArgument(
"Serialized sparse should be a matrix but received shape ",
serialized_sparse.shape().DebugString()));
OP_REQUIRES(
context, serialized_sparse.shape().dim_size(1) == 3,
errors::InvalidArgument(
"Serialize sparse should have 3 columns but received shape ",
serialized_sparse.shape().DebugString()));
int num_sparse_tensors = serialized_sparse.shape().dim_size(0);
OP_REQUIRES(
context, num_sparse_tensors > 0,
errors::InvalidArgument("Must have at least 1 serialized SparseTensor, "
"but input matrix has 0 rows"));
std::vector<Tensor> indices_to_concat;
std::vector<Tensor> values_to_concat;
std::vector<TensorShape> shapes_to_concat;
const auto& serialized_sparse_t = serialized_sparse.matrix<string>();
for (int i = 0; i < num_sparse_tensors; ++i) {
Tensor output_indices(DT_INT64);
Tensor output_values(DataTypeToEnum<T>::value);
Tensor output_shape(DT_INT64);
TensorProto proto_indices;
TensorProto proto_values;
TensorProto proto_shape;
OP_REQUIRES(context, ParseProtoUnlimited(&proto_indices,
serialized_sparse_t(i, 0)),
errors::InvalidArgument("Could not parse serialized_sparse[",
i, ", 0]"));
OP_REQUIRES(context,
ParseProtoUnlimited(&proto_values, serialized_sparse_t(i, 1)),
errors::InvalidArgument("Could not parse serialized_sparse[",
i, ", 1]"));
OP_REQUIRES(context,
ParseProtoUnlimited(&proto_shape, serialized_sparse_t(i, 2)),
errors::InvalidArgument("Could not parse serialized_sparse[",
i, ", 2]"));
OP_REQUIRES(context, output_indices.FromProto(proto_indices),
errors::InvalidArgument(
"Could not construct Tensor serialized_sparse[", i,
", 0] (indices)"));
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()),
errors::InvalidArgument(
"Expected serialized_sparse[", i,
", 1] to represent an index matrix but received shape ",
output_indices.shape().DebugString()));
OP_REQUIRES(context, output_values.FromProto(proto_values),
errors::InvalidArgument(
"Could not construct Tensor serialized_sparse[", i,
", 1] (values)"));
OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()),
errors::InvalidArgument(
"Expected serialized_sparse[", i,
", 1] to represent a values vector but received shape ",
output_values.shape().DebugString()));
OP_REQUIRES(context, output_shape.FromProto(proto_shape),
errors::InvalidArgument(
"Could not construct Tensor serialized_sparse[", i,
", 2] (shape)"));
OP_REQUIRES(
context, TensorShapeUtils::IsVector(output_shape.shape()),
errors::InvalidArgument("Expected serialized_sparse[", i,
", 1] to be a shape vector but its shape is ",
output_shape.shape().DebugString()));
OP_REQUIRES(
context, DataTypeToEnum<T>::value == output_values.dtype(),
errors::InvalidArgument(
"Requested SparseTensor of type ",
DataTypeString(DataTypeToEnum<T>::value), " but SparseTensor[", i,
"].values.dtype() == ", DataTypeString(output_values.dtype())));
int64 num_entries = output_indices.dim_size(0);
OP_REQUIRES(context, num_entries == output_values.dim_size(0),
errors::InvalidArgument(
"Expected row counts of SparseTensor[", i,
"].indices and SparseTensor[", i,
"].values to match but they do not: ", num_entries,
" vs. ", output_values.dim_size(0)));
int rank = output_indices.dim_size(1);
OP_REQUIRES(
context, rank == output_shape.dim_size(0),
errors::InvalidArgument("Expected column counts of SparseTensor[", i,
"].indices to match size of SparseTensor[", i,
"].shape "
"but they do not: ",
rank, " vs. ", output_shape.dim_size(0)));
// Now we expand each SparseTensors' indices and shape by
// prefixing a dimension
Tensor expanded_indices(
DT_INT64, TensorShape({num_entries, 1 + output_indices.dim_size(1)}));
Tensor expanded_shape(DT_INT64,
TensorShape({1 + output_shape.dim_size(0)}));
const auto& output_indices_t = output_indices.matrix<int64>();
const auto& output_shape_t = output_shape.vec<int64>();
auto expanded_indices_t = expanded_indices.matrix<int64>();
auto expanded_shape_t = expanded_shape.vec<int64>();
expanded_indices_t.chip<1>(0).setZero();
Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1);
Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank);
expanded_indices_t.slice(indices_start, indices_sizes) = output_indices_t;
expanded_shape_t(0) = 1;
std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1));
TensorShape expanded_tensor_shape(expanded_shape.vec<int64>());
indices_to_concat.push_back(expanded_indices);
values_to_concat.push_back(output_values);
shapes_to_concat.push_back(expanded_tensor_shape);
}
int rank = -1;
for (int i = 0; i < num_sparse_tensors; ++i) {
if (rank < 0) rank = shapes_to_concat[i].dims();
OP_REQUIRES(context, rank == shapes_to_concat[i].dims(),
errors::InvalidArgument(
"Inconsistent rank across SparseTensors: rank prior to "
"SparseTensor[",
i, "] was: ", rank, " but rank of SparseTensor[", i,
"] is: ", shapes_to_concat[i].dims()));
}
// SparseTensor::Concat requires consistent shape for all but the
// primary order dimension (dimension 0 in this case). So we get
// the maximum value across all the input SparseTensors for each
// dimension and use that.
TensorShape preconcat_shape(shapes_to_concat[0]);
for (int i = 0; i < num_sparse_tensors; ++i) {
for (int d = 0; d < rank; ++d) {
preconcat_shape.set_dim(d, std::max(preconcat_shape.dim_size(d),
shapes_to_concat[i].dim_size(d)));
}
}
// Dimension 0 is the primary dimension.
gtl::InlinedVector<int64, 8> std_order(rank);
std::iota(std_order.begin(), std_order.end(), 0);
std::vector<SparseTensor> tensors_to_concat;
for (int i = 0; i < num_sparse_tensors; ++i) {
tensors_to_concat.emplace_back(indices_to_concat[i], values_to_concat[i],
preconcat_shape, std_order);
}
SparseTensor output = SparseTensor::Concat<T>(tensors_to_concat);
Tensor final_output_shape(DT_INT64, TensorShape({output.dims()}));
std::copy_n(output.shape().dim_sizes().data(), output.dims(),
final_output_shape.vec<int64>().data());
context->set_output(0, output.indices());
context->set_output(1, output.values());
context->set_output(2, final_output_shape);
}
};
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \
DeserializeManySparseOp<type>)
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
} // namespace tensorflow
......@@ -28,6 +28,8 @@ namespace tensorflow {
class SummaryImageOp : public OpKernel {
public:
typedef Eigen::Tensor<uint8, 2, Eigen::RowMajor> Uint8Image;
explicit SummaryImageOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("max_images", &max_images_));
const TensorProto* proto;
......@@ -61,22 +63,56 @@ class SummaryImageOp : public OpKernel {
const int w = tensor.dim_size(2);
const int hw = h * w; // Compact these two dims for simplicity
const int depth = tensor.dim_size(3);
auto tensor_eigen = tensor.shaped<float, 3>({batch_size, hw, depth});
OP_REQUIRES(c, bad_color_.dim_size(0) >= depth,
errors::InvalidArgument(
"expected depth <= bad_color.size, got depth = ", depth,
", bad_color.size = ", bad_color_.dim_size(0)));
auto bad_color_full = bad_color_.vec<uint8>();
typename TTypes<uint8>::Vec bad_color(bad_color_full.data(), depth);
Summary s;
if (tensor.dtype() == DT_UINT8) {
// For uint8 input, no normalization is necessary
auto ith_image = [&tensor, batch_size, hw, depth](int i) {
auto values = tensor.shaped<uint8, 3>({batch_size, hw, depth});
return typename TTypes<uint8>::ConstMatrix(
&values(i, 0, 0), Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth));
};
AddImages(base_tag, batch_size, w, h, depth, ith_image, &s);
} else { // tensor.dtype() == DT_FLOAT
// For float images, nans and infs are replaced with bad_color.
OP_REQUIRES(c, bad_color_.dim_size(0) >= depth,
errors::InvalidArgument(
"expected depth <= bad_color.size, got depth = ", depth,
", bad_color.size = ", bad_color_.dim_size(0)));
auto bad_color_full = bad_color_.vec<uint8>();
typename TTypes<uint8>::ConstVec bad_color(bad_color_full.data(), depth);
// Float images must be scaled and translated.
Uint8Image image(hw, depth);
auto ith_image = [&tensor, &image, bad_color, batch_size, hw,
depth](int i) {
auto tensor_eigen = tensor.shaped<float, 3>({batch_size, hw, depth});
typename TTypes<float>::ConstMatrix values(
&tensor_eigen(i, 0, 0),
Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth));
NormalizeFloatImage(hw, depth, values, bad_color, &image);
return image;
};
AddImages(base_tag, batch_size, w, h, depth, ith_image, &s);
}
// RGB (or gray or RGBA) is last dimension
Eigen::Tensor<uint8, 2, Eigen::RowMajor> image(hw, depth);
Tensor* summary_tensor = nullptr;
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
CHECK(s.SerializeToString(&summary_tensor->scalar<string>()()));
}
Summary s;
// Add the sequence of images specified by ith_image to the summary.
//
// Factoring this loop out into a helper function lets ith_image behave
// differently in the float and uint8 cases: the float case needs a temporary
// buffer which can be shared across calls to ith_image, but the uint8 case
// does not.
Status AddImages(const string& tag, int batch_size, int w, int h, int depth,
const std::function<Uint8Image(int)>& ith_image,
Summary* s) {
const int N = std::min<int>(max_images_, batch_size);
for (int i = 0; i < N; ++i) {
Summary::Value* v = s.add_value();
Summary::Value* v = s->add_value();
// The tag depends on the number of requested images (not the number
// produced.)
//
......@@ -84,93 +120,94 @@ class SummaryImageOp : public OpKernel {
// convention for display, so we append "/image" to guarantee that the
// image(s) won't be displayed in the global scope with no name.
if (max_images_ > 1) {
v->set_tag(strings::StrCat(base_tag, "/image/", i));
v->set_tag(strings::StrCat(tag, "/image/", i));
} else {
v->set_tag(strings::StrCat(base_tag, "/image"));
v->set_tag(strings::StrCat(tag, "/image"));
}
if (image.size()) {
typename TTypes<float>::ConstMatrix values(
&tensor_eigen(i, 0, 0),
Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth));
// Rescale the image to uint8 range.
//
// We are trying to generate an RCG image from a float tensor. We do
// not have any info about the expected range of values in the tensor
// but the generated image needs to have all RGB values within [0, 255].
//
// We use two different algorithms to generate these values. If the
// tensor has only positive values we scale them all by 255/max(values).
// If the tensor has both negative and positive values we scale them by
// the max of their absolute values and center them around 127.
//
// This works for most cases, but has the incovenient of not respecting
// the relative dynamic range across different instances of the tensor.
// Compute min and max ignoring nonfinite pixels
float image_min = std::numeric_limits<float>::infinity();
float image_max = -image_min;
for (int i = 0; i < hw; i++) {
bool finite = true;
for (int j = 0; j < depth; j++) {
if (!std::isfinite(values(i, j))) {
finite = false;
break;
}
}
if (finite) {
for (int j = 0; j < depth; j++) {
float value = values(i, j);
image_min = std::min(image_min, value);
image_max = std::max(image_max, value);
}
}
}
auto image = ith_image(i);
Summary::Image* si = v->mutable_image();
si->set_height(h);
si->set_width(w);
si->set_colorspace(depth);
const int channel_bits = 8;
const int compression = -1; // Use zlib default
if (!png::WriteImageToBuffer(
image.data(), w, h, w * depth, depth, channel_bits, compression,
si->mutable_encoded_image_string(), nullptr)) {
return errors::Internal("PNG encoding failed");
}
}
return Status::OK();
}
// Pick an affine transform into uint8
const float kZeroThreshold = 1e-6;
float scale, offset;
if (image_min < 0) {
float max_val = std::max(std::abs(image_min), std::abs(image_max));
scale = max_val < kZeroThreshold ? 0.0f : 127.0f / max_val;
offset = 128.0f;
} else {
scale = image_max < kZeroThreshold ? 0.0f : 255.0f / image_max;
offset = 0.0f;
static void NormalizeFloatImage(int hw, int depth,
typename TTypes<float>::ConstMatrix values,
typename TTypes<uint8>::ConstVec bad_color,
Uint8Image* image) {
if (!image->size()) return; // Nothing to do for empty images
// Rescale the image to uint8 range.
//
// We are trying to generate an RGB image from a float tensor. We do
// not have any info about the expected range of values in the tensor
// but the generated image needs to have all RGB values within [0, 255].
//
// We use two different algorithms to generate these values. If the
// tensor has only positive values we scale them all by 255/max(values).
// If the tensor has both negative and positive values we scale them by
// the max of their absolute values and center them around 127.
//
// This works for most cases, but does not respect the relative dynamic
// range across different instances of the tensor.
// Compute min and max ignoring nonfinite pixels
float image_min = std::numeric_limits<float>::infinity();
float image_max = -image_min;
for (int i = 0; i < hw; i++) {
bool finite = true;
for (int j = 0; j < depth; j++) {
if (!std::isfinite(values(i, j))) {
finite = false;
break;
}
// Transform image, turning nonfinite values to bad_color
for (int i = 0; i < hw; i++) {
bool finite = true;
for (int j = 0; j < depth; j++) {
if (!std::isfinite(values(i, j))) {
finite = false;
break;
}
}
if (finite) {
image.chip<0>(i) =
(values.chip<0>(i) * scale + offset).cast<uint8>();
} else {
image.chip<0>(i) = bad_color;
}
}
if (finite) {
for (int j = 0; j < depth; j++) {
float value = values(i, j);
image_min = std::min(image_min, value);
image_max = std::max(image_max, value);
}
}
}
Summary::Image* si = v->mutable_image();
si->set_height(h);
si->set_width(w);
si->set_colorspace(depth);
OP_REQUIRES(c, png::WriteImageToBuffer(
image.data(), w, h, w * depth, depth, 8, -1,
si->mutable_encoded_image_string(), nullptr),
errors::Internal("PNG encoding failed"));
// Pick an affine transform into uint8
const float kZeroThreshold = 1e-6;
float scale, offset;
if (image_min < 0) {
float max_val = std::max(std::abs(image_min), std::abs(image_max));
scale = max_val < kZeroThreshold ? 0.0f : 127.0f / max_val;
offset = 128.0f;
} else {
scale = image_max < kZeroThreshold ? 0.0f : 255.0f / image_max;
offset = 0.0f;
}
Tensor* summary_tensor = nullptr;
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
CHECK(s.SerializeToString(&summary_tensor->scalar<string>()()));
// Transform image, turning nonfinite values to bad_color
for (int i = 0; i < hw; i++) {
bool finite = true;
for (int j = 0; j < depth; j++) {
if (!std::isfinite(values(i, j))) {
finite = false;
break;
}
}
if (finite) {
image->chip<0>(i) = (values.chip<0>(i) * scale + offset).cast<uint8>();
} else {
image->chip<0>(i) = bad_color;
}
}
}
private:
......
......@@ -21,6 +21,7 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/core/errors.h"
......@@ -105,14 +106,12 @@ class SummaryHistoOp : public OpKernel {
}
};
REGISTER_KERNEL_BUILDER(Name("HistogramSummary")
.Device(DEVICE_CPU)
.TypeConstraint<float>("T"),
SummaryHistoOp<float>);
REGISTER_KERNEL_BUILDER(Name("HistogramSummary")
.Device(DEVICE_CPU)
.TypeConstraint<double>("T"),
SummaryHistoOp<double>);
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER( \
Name("HistogramSummary").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
SummaryHistoOp<T>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER)
#undef REGISTER
struct HistogramResource : public ResourceBase {
histogram::ThreadSafeHistogram histogram;
......
......@@ -71,6 +71,23 @@ resized_images: 4-D with shape
`[batch, new_height, new_width, channels]`.
)doc");
// --------------------------------------------------------------------------
REGISTER_OP("ResizeBilinearGrad")
.Input("grads: float")
.Input("original_image: T")
.Output("output: T")
.Attr("T: {float, double}")
.Doc(R"doc(
Computes the gradient of bilinear interpolation.
grads: 4-D with shape `[batch, height, width, channels]`.
original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`,
The image tensor that was resized.
output: 4-D with shape `[batch, orig_height, orig_width, channels]`.
Gradients with respect to the input image. Input image must have been
float or double.
)doc");
// --------------------------------------------------------------------------
REGISTER_OP("ResizeNearestNeighbor")
.Input("images: T")
......@@ -80,8 +97,6 @@ REGISTER_OP("ResizeNearestNeighbor")
.Doc(R"doc(
Resize `images` to `size` using nearest neighbor interpolation.
Input images can be of different types but output images are always float.
images: 4-D with shape `[batch, height, width, channels]`.
size:= A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
new size for the images.
......@@ -301,4 +316,40 @@ compression: Compression level.
contents: 0-D. PNG-encoded image.
)doc");
// --------------------------------------------------------------------------
REGISTER_OP("RGBToHSV")
.Input("images: float")
.Output("output: float")
.Doc(R"doc(
Converts one or more images from RGB to HSV.
Outputs a tensor of the same shape as the `images` tensor, containing the HSV
value of the pixels. The output is only well defined if the value in `images`
are in `[0,1]`.
`output[..., 0]` contains hue, `output[..., 1]` contains saturation, and
`output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0
corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue.
images: 1-D or higher rank. RGB data to convert. Last dimension must be size 3.
output: `images` converted to HSV.
)doc");
// --------------------------------------------------------------------------
REGISTER_OP("HSVToRGB")
.Input("images: float")
.Output("output: float")
.Doc(R"doc(
Convert one or more images from HSV to RGB.
Outputs a tensor of the same shape as the `images` tensor, containing the RGB
value of the pixels. The output is only well defined if the value in `images`
are in `[0,1]`.
See `rgb_to_hsv` for a description of the HSV encoding.
images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3.
output: `images` converted to RGB.
)doc");
} // namespace tensorflow
......@@ -1929,6 +1929,33 @@ op {
}
summary: "Reinterpret the bytes of a string as a vector of numbers."
}
op {
name: "DeserializeManySparse"
input_arg {
name: "serialized_sparse"
description: "2-D, The `N` serialized `SparseTensor` objects.\nMust have 3 columns."
type: DT_STRING
}
output_arg {
name: "sparse_indices"
type: DT_INT64
}
output_arg {
name: "sparse_values"
type_attr: "dtype"
}
output_arg {
name: "sparse_shape"
type: DT_INT64
}
attr {
name: "dtype"
type: "type"
description: "The `dtype` of the serialized `SparseTensor` objects."
}
summary: "Deserialize and concatenate `SparseTensors` from a serialized minibatch."
description: "The input `serialized_sparse` must be a string matrix of shape `[N x 3]` where\n`N` is the minibatch size and the rows correspond to packed outputs of\n`SerializeSparse`. The ranks of the original `SparseTensor` objects\nmust all match. When the final `SparseTensor` is created, it has rank one\nhigher than the ranks of the incoming `SparseTensor` objects\n(they have been concatenated along a new row dimension).\n\nThe output `SparseTensor` object\'s shape values for all dimensions but the\nfirst are the max across the input `SparseTensor` objects\' shape values\nfor the corresponding dimensions. Its first shape value is `N`, the minibatch\nsize.\n\nThe input `SparseTensor` objects\' indices are assumed ordered in\nstandard lexicographic order. If this is not the case, after this\nstep run `SparseReorder` to restore index ordering.\n\nFor example, if the serialized input is a `[2 x 3]` matrix representing two\noriginal `SparseTensor` objects:\n\n index = [ 0]\n [10]\n [20]\n values = [1, 2, 3]\n shape = [50]\n\nand\n\n index = [ 2]\n [10]\n values = [4, 5]\n shape = [30]\n\nthen the final deserialized `SparseTensor` will be:\n\n index = [0 0]\n [0 10]\n [0 20]\n [1 2]\n [1 10]\n values = [1, 2, 3, 4, 5]\n shape = [2 50]"
}
op {
name: "DestroyTemporaryVariable"
input_arg {
......@@ -2842,6 +2869,21 @@ op {
}
summary: "Returns the truth value of (x >= y) element-wise."
}
op {
name: "HSVToRGB"
input_arg {
name: "images"
description: "1-D or higher rank. HSV data to convert. Last dimension must be size 3."
type: DT_FLOAT
}
output_arg {
name: "output"
description: "`images` converted to RGB."
type: DT_FLOAT
}
summary: "Convert one or more images from HSV to RGB."
description: "Outputs a tensor of the same shape as the `images` tensor, containing the RGB\nvalue of the pixels. The output is only well defined if the value in `images`\nare in `[0,1]`.\n\nSee `rgb_to_hsv` for a description of the HSV encoding."
}
op {
name: "HashTable"
output_arg {
......@@ -2906,6 +2948,11 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
}
}
}
......@@ -2979,7 +3026,7 @@ op {
input_arg {
name: "tensor"
description: "4-D of shape `[batch_size, height, width, channels]` where\n`channels` is 1, 3, or 4."
type: DT_FLOAT
type_attr: "T"
}
output_arg {
name: "summary"
......@@ -2996,6 +3043,19 @@ op {
has_minimum: true
minimum: 1
}
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_UINT8
type: DT_FLOAT
}
}
}
attr {
name: "bad_color"
type: "tensor"
......@@ -3016,7 +3076,7 @@ op {
description: "Color to use for pixels with non-finite values."
}
summary: "Outputs a `Summary` protocol buffer with images."
description: "The summary has up to `max_images` summary values containing images. The\nimages are built from `tensor` which must be 4-D with shape `[batch_size,\nheight, width, channels]` and where `channels` can be:\n\n* 1: `tensor` is interpreted as Grayscale.\n* 3: `tensor` is interpreted as RGB.\n* 4: `tensor` is interpreted as RGBA.\n\nThe images have the same number of channels as the input tensor. Their values\nare normalized, one image at a time, to fit in the range `[0, 255]`. The\nop uses two different normalization algorithms:\n\n* If the input values are all positive, they are rescaled so the largest one\n is 255.\n\n* If any input value is negative, the values are shifted so input value 0.0\n is at 127. They are then rescaled so that either the smallest value is 0,\n or the largest one is 255.\n\nThe `tag` argument is a scalar `Tensor` of type `string`. It is used to\nbuild the `tag` of the summary values:\n\n* If `max_images` is 1, the summary value tag is \'*tag*/image\'.\n* If `max_images` is greater than 1, the summary value tags are\n generated sequentially as \'*tag*/image/0\', \'*tag*/image/1\', etc.\n\nThe `bad_color` argument is the color to use in the generated images for\nnon-finite input values. It is a `unit8` 1-D tensor of length `channels`.\nEach element must be in the range `[0, 255]` (It represents the value of a\npixel in the output image). Non-finite values in the input tensor are\nreplaced by this tensor in the output image. The default value is the color\nred."
description: "The summary has up to `max_images` summary values containing images. The\nimages are built from `tensor` which must be 4-D with shape `[batch_size,\nheight, width, channels]` and where `channels` can be:\n\n* 1: `tensor` is interpreted as Grayscale.\n* 3: `tensor` is interpreted as RGB.\n* 4: `tensor` is interpreted as RGBA.\n\nThe images have the same number of channels as the input tensor. For float\ninput, the values are normalized one image at a time to fit in the range\n`[0, 255]`. `uint8` values are unchanged. The op uses two different\nnormalization algorithms:\n\n* If the input values are all positive, they are rescaled so the largest one\n is 255.\n\n* If any input value is negative, the values are shifted so input value 0.0\n is at 127. They are then rescaled so that either the smallest value is 0,\n or the largest one is 255.\n\nThe `tag` argument is a scalar `Tensor` of type `string`. It is used to\nbuild the `tag` of the summary values:\n\n* If `max_images` is 1, the summary value tag is \'*tag*/image\'.\n* If `max_images` is greater than 1, the summary value tags are\n generated sequentially as \'*tag*/image/0\', \'*tag*/image/1\', etc.\n\nThe `bad_color` argument is the color to use in the generated images for\nnon-finite input values. It is a `unit8` 1-D tensor of length `channels`.\nEach element must be in the range `[0, 255]` (It represents the value of a\npixel in the output image). Non-finite values in the input tensor are\nreplaced by this tensor in the output image. The default value is the color\nred."
}
op {
name: "InTopK"
......@@ -5043,6 +5103,21 @@ op {
}
summary: "Computes the number of elements in the given queue."
}
op {
name: "RGBToHSV"
input_arg {
name: "images"
description: "1-D or higher rank. RGB data to convert. Last dimension must be size 3."
type: DT_FLOAT
}
output_arg {
name: "output"
description: "`images` converted to HSV."
type: DT_FLOAT
}
summary: "Converts one or more images from RGB to HSV."
description: "Outputs a tensor of the same shape as the `images` tensor, containing the HSV\nvalue of the pixels. The output is only well defined if the value in `images`\nare in `[0,1]`.\n\n`output[..., 0]` contains hue, `output[..., 1]` contains saturation, and\n`output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0\ncorresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue."
}
op {
name: "RandomCrop"
input_arg {
......@@ -5906,6 +5981,35 @@ op {
summary: "Resize `images` to `size` using bilinear interpolation."
description: "Input images can be of different types but output images are always float."
}
op {
name: "ResizeBilinearGrad"
input_arg {
name: "grads"
description: "4-D with shape `[batch, height, width, channels]`."
type: DT_FLOAT
}
input_arg {
name: "original_image"
description: "4-D with shape `[batch, orig_height, orig_width, channels]`,\nThe image tensor that was resized."
type_attr: "T"
}
output_arg {
name: "output"
description: "4-D with shape `[batch, orig_height, orig_width, channels]`.\nGradients with respect to the input image. Input image must have been\nfloat or double."
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
summary: "Computes the gradient of bilinear interpolation."
}
op {
name: "ResizeNearestNeighbor"
input_arg {
......@@ -5937,7 +6041,6 @@ op {
}
}
summary: "Resize `images` to `size` using nearest neighbor interpolation."
description: "Input images can be of different types but output images are always float."
}
op {
name: "ResizeNearestNeighborGrad"
......@@ -6701,6 +6804,61 @@ op {
summary: "Calculates the Eigen Decomposition of a square Self-Adjoint matrix."
description: "Only the lower-triangular part of the input will be used in this case. The\nupper-triangular part will not be read.\n\nThe result is a M+1 x M matrix whose first row is the eigenvalues, and\nsubsequent rows are eigenvectors."
}
op {
name: "SerializeManySparse"
input_arg {
name: "sparse_indices"
description: "2-D. The `indices` of the minibatch `SparseTensor`."
type: DT_INT64
}
input_arg {
name: "sparse_values"
description: "1-D. The `values` of the minibatch `SparseTensor`."
type_attr: "T"
}
input_arg {
name: "sparse_shape"
description: "1-D. The `shape` of the minibatch `SparseTensor`."
type: DT_INT64
}
output_arg {
name: "serialized_sparse"
type: DT_STRING
}
attr {
name: "T"
type: "type"
}
summary: "Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` string `Tensor`."
description: "The `SparseTensor` must have rank `R` greater than 1, and the first dimension\nis treated as the minibatch dimension. Elements of the `SparseTensor`\nmust be sorted in increasing order of this first dimension. The serialized\n`SparseTensor` objects going into each row of `serialized_sparse` will have\nrank `R-1`.\n\nThe minibatch size `N` is extracted from `sparse_shape[0]`."
}
op {
name: "SerializeSparse"
input_arg {
name: "sparse_indices"
description: "2-D. The `indices` of the `SparseTensor`."
type: DT_INT64
}
input_arg {
name: "sparse_values"
description: "1-D. The `values` of the `SparseTensor`."
type_attr: "T"
}
input_arg {
name: "sparse_shape"
description: "1-D. The `shape` of the `SparseTensor`."
type: DT_INT64
}
output_arg {
name: "serialized_sparse"
type: DT_STRING
}
attr {
name: "T"
type: "type"
}
summary: "Serialize a `SparseTensor` into a string 3-vector (1-D `Tensor`) object."
}
op {
name: "Shape"
input_arg {
......
......@@ -17,6 +17,98 @@ limitations under the License.
namespace tensorflow {
REGISTER_OP("SerializeSparse")
.Input("sparse_indices: int64")
.Input("sparse_values: T")
.Input("sparse_shape: int64")
.Attr("T: type")
.Output("serialized_sparse: string")
.Doc(R"doc(
Serialize a `SparseTensor` into a string 3-vector (1-D `Tensor`) object.
sparse_indices: 2-D. The `indices` of the `SparseTensor`.
sparse_values: 1-D. The `values` of the `SparseTensor`.
sparse_shape: 1-D. The `shape` of the `SparseTensor`.
)doc");
REGISTER_OP("SerializeManySparse")
.Input("sparse_indices: int64")
.Input("sparse_values: T")
.Input("sparse_shape: int64")
.Attr("T: type")
.Output("serialized_sparse: string")
.Doc(R"doc(
Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` string `Tensor`.
The `SparseTensor` must have rank `R` greater than 1, and the first dimension
is treated as the minibatch dimension. Elements of the `SparseTensor`
must be sorted in increasing order of this first dimension. The serialized
`SparseTensor` objects going into each row of `serialized_sparse` will have
rank `R-1`.
The minibatch size `N` is extracted from `sparse_shape[0]`.
sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`.
sparse_values: 1-D. The `values` of the minibatch `SparseTensor`.
sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`.
)doc");
REGISTER_OP("DeserializeManySparse")
.Input("serialized_sparse: string")
.Attr("dtype: type")
.Output("sparse_indices: int64")
.Output("sparse_values: dtype")
.Output("sparse_shape: int64")
.Doc(R"doc(
Deserialize and concatenate `SparseTensors` from a serialized minibatch.
The input `serialized_sparse` must be a string matrix of shape `[N x 3]` where
`N` is the minibatch size and the rows correspond to packed outputs of
`SerializeSparse`. The ranks of the original `SparseTensor` objects
must all match. When the final `SparseTensor` is created, it has rank one
higher than the ranks of the incoming `SparseTensor` objects
(they have been concatenated along a new row dimension).
The output `SparseTensor` object's shape values for all dimensions but the
first are the max across the input `SparseTensor` objects' shape values
for the corresponding dimensions. Its first shape value is `N`, the minibatch
size.
The input `SparseTensor` objects' indices are assumed ordered in
standard lexicographic order. If this is not the case, after this
step run `SparseReorder` to restore index ordering.
For example, if the serialized input is a `[2 x 3]` matrix representing two
original `SparseTensor` objects:
index = [ 0]
[10]
[20]
values = [1, 2, 3]
shape = [50]
and
index = [ 2]
[10]
values = [4, 5]
shape = [30]
then the final deserialized `SparseTensor` will be:
index = [0 0]
[0 10]
[0 20]
[1 2]
[1 10]
values = [1, 2, 3, 4, 5]
shape = [2 50]
serialized_sparse: 2-D, The `N` serialized `SparseTensor` objects.
Must have 3 columns.
dtype: The `dtype` of the serialized `SparseTensor` objects.
)doc");
REGISTER_OP("SparseToDense")
.Input("sparse_indices: Tindices")
.Input("output_shape: Tindices")
......
......@@ -40,7 +40,7 @@ REGISTER_OP("HistogramSummary")
.Input("tag: string")
.Input("values: T")
.Output("summary: string")
.Attr("T: {float, double} = DT_FLOAT")
.Attr("T: realnumbertype = DT_FLOAT")
.Doc(R"doc(
Outputs a `Summary` protocol buffer with a histogram.
......@@ -57,9 +57,10 @@ summary: Scalar. Serialized `Summary` protocol buffer.
REGISTER_OP("ImageSummary")
.Input("tag: string")
.Input("tensor: float")
.Input("tensor: T")
.Output("summary: string")
.Attr("max_images: int >= 1 = 3")
.Attr("T: {uint8, float} = DT_FLOAT")
.Attr(
"bad_color: tensor = { dtype: DT_UINT8 "
"tensor_shape: { dim { size: 4 } } "
......@@ -75,9 +76,10 @@ height, width, channels]` and where `channels` can be:
* 3: `tensor` is interpreted as RGB.
* 4: `tensor` is interpreted as RGBA.
The images have the same number of channels as the input tensor. Their values
are normalized, one image at a time, to fit in the range `[0, 255]`. The
op uses two different normalization algorithms:
The images have the same number of channels as the input tensor. For float
input, the values are normalized one image at a time to fit in the range
`[0, 255]`. `uint8` values are unchanged. The op uses two different
normalization algorithms:
* If the input values are all positive, they are rescaled so the largest one
is 255.
......
......@@ -59,6 +59,8 @@ class SparseTensor {
std::size_t num_entries() const { return ix_.dim_size(0); }
int dims() const { return shape_.dims(); }
const Tensor& indices() const { return ix_; }
const Tensor& values() const { return vals_; }
......
......@@ -1203,6 +1203,13 @@ Returns `True` if this `DType` represents a reference type.
Returns a reference `DType` based on this `DType`.
- - -
#### `tf.DType.is_floating` {#DType.is_floating}
Returns whether this is a (real) floating point type.
- - -
#### `tf.DType.is_integer` {#DType.is_integer}
......@@ -1217,6 +1224,20 @@ Returns whether this is a (non-quantized) integer type.
Returns whether this is a quantized data type.
- - -
#### `tf.DType.is_unsigned` {#DType.is_unsigned}
Returns whether this type is unsigned.
Non-numeric, unordered, and quantized types are not considered unsigned, and
this function returns `False`.
##### Returns:
Whether a `DType` is unsigned.
- - -
......@@ -1255,13 +1276,6 @@ construct a `DataType` object directly. Instead, use the
* <b>`TypeError`</b>: If `type_enum` is not a value `types_pb2.DataType`.
- - -
#### `tf.DType.is_floating` {#DType.is_floating}
Returns whether this is a (real) floating point type.
- - -
#### `tf.DType.max` {#DType.max}
......
......@@ -660,10 +660,128 @@ See also `transpose()`.
## Converting Between Colorspaces.
Image ops work either on individual images or on batches of images, depending on
the shape of their input Tensor.
If 3-D, the shape is `[height, width, channels]`, and the Tensor represents one
image. If 4-D, the shape is `[batch_size, height, width, channels]`, and the
Tensor represents `batch_size` images.
Currently, `channels` can usefully be 1, 2, 3, or 4. Single-channel images are
grayscale, images with 3 channels are encoded as either RGB or HSV. Images
with 2 or 4 channels include an alpha channel, which has to be stripped from the
image before passing the image to most image processing functions (and can be
re-attached later).
Internally, images are either stored in as one `float32` per channel per pixel
(implicitly, values are assumed to lie in `[0,1)`) or one `uint8` per channel
per pixel (values are assumed to lie in `[0,255]`).
Tensorflow can convert between images in RGB or HSV. The conversion functions
work only on float images, so you need to convert images in other formats using
[`convert_image_dtype`](#convert-image-dtype).
Example:
```python
# Decode an image and convert it to HSV.
rgb_image = tf.decode_png(..., channels=3)
rgb_image_float = tf.convert_image_dtype(rgb_image, tf.float32)
hsv_image = tf.hsv_to_rgb(rgb_image)
```
- - -
### `tf.image.rgb_to_grayscale(images)` {#rgb_to_grayscale}
Converts one or more images from RGB to Grayscale.
Outputs a tensor of the same `DType` and rank as `images`. The size of the
last dimension of the output is 1, containing the Grayscale value of the
pixels.
##### Args:
* <b>`images`</b>: The RGB tensor to convert. Last dimension must have size 3 and
should contain RGB values.
##### Returns:
The converted grayscale image(s).
- - -
### `tf.image.grayscale_to_rgb(images)` {#grayscale_to_rgb}
Converts one or more images from Grayscale to RGB.
Outputs a tensor of the same `DType` and rank as `images`. The size of the
last dimension of the output is 3, containing the RGB value of the pixels.
##### Args:
* <b>`images`</b>: The Grayscale tensor to convert. Last dimension must be size 1.
##### Returns:
The converted grayscale image(s).
- - -
### `tf.image.hsv_to_rgb(images, name=None)` {#hsv_to_rgb}
Convert one or more images from HSV to RGB.
Outputs a tensor of the same shape as the `images` tensor, containing the RGB
value of the pixels. The output is only well defined if the value in `images`
are in `[0,1]`.
See `rgb_to_hsv` for a description of the HSV encoding.
##### Args:
* <b>`images`</b>: A `Tensor` of type `float32`.
1-D or higher rank. HSV data to convert. Last dimension must be size 3.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
A `Tensor` of type `float32`. `images` converted to RGB.
- - -
### `tf.image.rgb_to_hsv(images, name=None)` {#rgb_to_hsv}
Converts one or more images from RGB to HSV.
Outputs a tensor of the same shape as the `images` tensor, containing the HSV
value of the pixels. The output is only well defined if the value in `images`
are in `[0,1]`.
`output[..., 0]` contains hue, `output[..., 1]` contains saturation, and
`output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0
corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue.
##### Args:
* <b>`images`</b>: A `Tensor` of type `float32`.
1-D or higher rank. RGB data to convert. Last dimension must be size 3.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
A `Tensor` of type `float32`. `images` converted to HSV.
- - -
### `tf.image.convert_image_dtype(image, dtype, name=None)` {#convert_image_dtype}
......@@ -699,9 +817,13 @@ overflow errors when converted to integer `Dtype`s.
TensorFlow provides functions to adjust images in various ways: brightness,
contrast, hue, and saturation. Each adjustment can be done with predefined
parameters or with random parameters picked from predefined intervals. Random
parameters or with random parameters picked from predefined intervals. Random
adjustments are often useful to expand a training set and reduce overfitting.
If several adjustments are chained it is advisable to minimize the number of
redundant conversions by first converting the images to the most natural data
type and representation (RGB or HSV).
- - -
### `tf.image.adjust_brightness(image, delta, min_value=None, max_value=None)` {#adjust_brightness}
......@@ -833,6 +955,126 @@ picked in the interval `[lower, upper]`.
- - -
### `tf.image.adjust_hue(image, delta, name=None)` {#adjust_hue}
Adjust hue of an RGB image.
This is a convenience method that converts an RGB image to float
representation, converts it to HSV, add an offset to the hue channel, converts
back to RGB and then back to the original data type. If several adjustments
are chained it is advisable to minimize the number of redundant conversions.
`image` is an RGB image. The image hue is adjusted by converting the
image to HSV and rotating the hue channel (H) by
`delta`. The image is then converted back to RGB.
`delta` must be in the interval `[-1, 1]`.
##### Args:
* <b>`image`</b>: RGB image or images. Size of the last dimension must be 3.
* <b>`delta`</b>: float. How much to add to the hue channel.
* <b>`name`</b>: A name for this operation (optional).
##### Returns:
Adjusted image(s), same shape and DType as `image`.
- - -
### `tf.image.random_hue(image, max_delta, seed=None)` {#random_hue}
Adjust the hue of an RGB image by a random factor.
Equivalent to `adjust_hue()` but uses a `delta` randomly
picked in the interval `[-max_delta, max_delta]`.
`max_delta` must be in the interval `[0, 0.5]`.
##### Args:
* <b>`image`</b>: RGB image or images. Size of the last dimension must be 3.
* <b>`max_delta`</b>: float. Maximum value for the random delta.
* <b>`seed`</b>: An operation-specific seed. It will be used in conjunction
with the graph-level seed to determine the real seeds that will be
used in this operation. Please see the documentation of
set_random_seed for its interaction with the graph-level random seed.
##### Returns:
3-D float tensor of shape `[height, width, channels]`.
##### Raises:
* <b>`ValueError`</b>: if `max_delta` is invalid.
- - -
### `tf.image.adjust_saturation(image, saturation_factor, name=None)` {#adjust_saturation}
Adjust staturation of an RGB image.
This is a convenience method that converts an RGB image to float
representation, converts it to HSV, add an offset to the saturation channel,
converts back to RGB and then back to the original data type. If several
adjustments are chained it is advisable to minimize the number of redundant
conversions.
`image` is an RGB image. The image saturation is adjusted by converting the
image to HSV and multiplying the saturation (S) channel by
`saturation_factor` and clipping. The image is then converted back to RGB.
##### Args:
* <b>`image`</b>: RGB image or images. Size of the last dimension must be 3.
* <b>`saturation_factor`</b>: float. Factor to multiply the saturation by.
* <b>`name`</b>: A name for this operation (optional).
##### Returns:
Adjusted image(s), same shape and DType as `image`.
- - -
### `tf.image.random_saturation(image, lower, upper, seed=None)` {#random_saturation}
Adjust the saturation of an RGB image by a random factor.
Equivalent to `adjust_saturation()` but uses a `saturation_factor` randomly
picked in the interval `[lower, upper]`.
##### Args:
* <b>`image`</b>: RGB image or images. Size of the last dimension must be 3.
* <b>`lower`</b>: float. Lower bound for the random saturation factor.
* <b>`upper`</b>: float. Upper bound for the random saturation factor.
* <b>`seed`</b>: An operation-specific seed. It will be used in conjunction
with the graph-level seed to determine the real seeds that will be
used in this operation. Please see the documentation of
set_random_seed for its interaction with the graph-level random seed.
##### Returns:
Adjusted image(s), same shape and DType as `image`.
##### Raises:
* <b>`ValueError`</b>: if `upper <= lower` or if `lower < 0`.
- - -
### `tf.image.per_image_whitening(image)` {#per_image_whitening}
......
......@@ -198,6 +198,8 @@
* **[Images](../../api_docs/python/image.md)**:
* [`adjust_brightness`](../../api_docs/python/image.md#adjust_brightness)
* [`adjust_contrast`](../../api_docs/python/image.md#adjust_contrast)
* [`adjust_hue`](../../api_docs/python/image.md#adjust_hue)
* [`adjust_saturation`](../../api_docs/python/image.md#adjust_saturation)
* [`convert_image_dtype`](../../api_docs/python/image.md#convert_image_dtype)
* [`crop_to_bounding_box`](../../api_docs/python/image.md#crop_to_bounding_box)
* [`decode_jpeg`](../../api_docs/python/image.md#decode_jpeg)
......@@ -207,6 +209,8 @@
* [`extract_glimpse`](../../api_docs/python/image.md#extract_glimpse)
* [`flip_left_right`](../../api_docs/python/image.md#flip_left_right)
* [`flip_up_down`](../../api_docs/python/image.md#flip_up_down)
* [`grayscale_to_rgb`](../../api_docs/python/image.md#grayscale_to_rgb)
* [`hsv_to_rgb`](../../api_docs/python/image.md#hsv_to_rgb)
* [`pad_to_bounding_box`](../../api_docs/python/image.md#pad_to_bounding_box)
* [`per_image_whitening`](../../api_docs/python/image.md#per_image_whitening)
* [`random_brightness`](../../api_docs/python/image.md#random_brightness)
......@@ -214,6 +218,8 @@
* [`random_crop`](../../api_docs/python/image.md#random_crop)
* [`random_flip_left_right`](../../api_docs/python/image.md#random_flip_left_right)
* [`random_flip_up_down`](../../api_docs/python/image.md#random_flip_up_down)
* [`random_hue`](../../api_docs/python/image.md#random_hue)
* [`random_saturation`](../../api_docs/python/image.md#random_saturation)
* [`resize_area`](../../api_docs/python/image.md#resize_area)
* [`resize_bicubic`](../../api_docs/python/image.md#resize_bicubic)
* [`resize_bilinear`](../../api_docs/python/image.md#resize_bilinear)
......@@ -221,6 +227,8 @@
* [`resize_images`](../../api_docs/python/image.md#resize_images)
* [`resize_nearest_neighbor`](../../api_docs/python/image.md#resize_nearest_neighbor)
* [`resize_nearest_neighbor_grad`](../../api_docs/python/image.md#resize_nearest_neighbor_grad)
* [`rgb_to_grayscale`](../../api_docs/python/image.md#rgb_to_grayscale)
* [`rgb_to_hsv`](../../api_docs/python/image.md#rgb_to_hsv)
* [`transpose_image`](../../api_docs/python/image.md#transpose_image)
* **[Sparse Tensors](../../api_docs/python/sparse_ops.md)**:
......
......@@ -1138,7 +1138,7 @@ compute them approximately.
- - -
### `tf.nn.fixed_unigram_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, vocab_file='', distortion=0.0, num_reserved_ids=0, num_shards=1, shard=0, unigrams=[], seed=None, name=None)` {#fixed_unigram_candidate_sampler}
### `tf.nn.fixed_unigram_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, vocab_file='', distortion=1.0, num_reserved_ids=0, num_shards=1, shard=0, unigrams=[], seed=None, name=None)` {#fixed_unigram_candidate_sampler}
Samples a set of classes using the provided (fixed) base distribution.
......
......@@ -890,9 +890,9 @@ the constructor is used. If that one is `None` too, a
* <b>`dtype`</b>: type of the new or existing variable (defaults to `DT_FLOAT`).
* <b>`initializer`</b>: initializer for the variable if one is created.
* <b>`trainable`</b>: If `True` also add the variable to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see variables.Variable).
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
* <b>`collections`</b>: List of graph collections keys to add the Variable to.
Defaults to `[GraphKeys.VARIABLES]` (see variables.Variable).
Defaults to `[GraphKeys.VARIABLES]` (see tf.Variable).
##### Returns:
......
......@@ -30,7 +30,7 @@ class directly, but instead instantiate one of its subclasses such as
# Create an optimizer with the desired parameters.
opt = GradientDescentOptimizer(learning_rate=0.1)
# Add Ops to the graph to minimize a cost by updating a list of variables.
# "cost" is a Tensor, and the list of variables contains variables.Variable
# "cost" is a Tensor, and the list of variables contains tf.Variable
# objects.
opt_op = opt.minimize(cost, <list of variables>)
```
......@@ -145,7 +145,7 @@ given variable.
* <b>`loss`</b>: A Tensor containing the value to minimize.
* <b>`var_list`</b>: Optional list of variables.Variable to update to minimize
* <b>`var_list`</b>: Optional list of tf.Variable to update to minimize
`loss`. Defaults to the list of variables collected in the graph
under the key `GraphKey.TRAINABLE_VARIABLES`.
* <b>`gate_gradients`</b>: How to gate the computation of gradients. Can be
......@@ -1412,9 +1412,10 @@ height, width, channels]` and where `channels` can be:
* 3: `tensor` is interpreted as RGB.
* 4: `tensor` is interpreted as RGBA.
The images have the same number of channels as the input tensor. Their values
are normalized, one image at a time, to fit in the range `[0, 255]`. The
op uses two different normalization algorithms:
The images have the same number of channels as the input tensor. For float
input, the values are normalized one image at a time to fit in the range
`[0, 255]`. `uint8` values are unchanged. The op uses two different
normalization algorithms:
* If the input values are all positive, they are rescaled so the largest one
is 255.
......@@ -1435,8 +1436,8 @@ build the `tag` of the summary values:
* <b>`tag`</b>: A scalar `Tensor` of type `string`. Used to build the `tag`
of the summary values.
* <b>`tensor`</b>: A 4-D `float32` `Tensor` of shape `[batch_size, height, width,
channels]` where `channels` is 1, 3, or 4.
* <b>`tensor`</b>: A 4-D `uint8` or `float32` `Tensor` of shape `[batch_size, height,
width, channels]` where `channels` is 1, 3, or 4.
* <b>`max_images`</b>: Max number of batch elements to generate images for.
* <b>`collections`</b>: Optional list of ops.GraphKeys. The collections to add the
summary to. Defaults to [ops.GraphKeys.SUMMARIES]
......@@ -1464,7 +1465,7 @@ This op reports an `OutOfRange` error if any value is not finite.
* <b>`tag`</b>: A `string` `Tensor`. 0-D. Tag to use for the summary value.
* <b>`values`</b>: A `float32` or `float64` `Tensor`. Any shape. Values to use to
* <b>`values`</b>: A real numeric `Tensor`. Any shape. Values to use to
build the histogram.
* <b>`collections`</b>: Optional list of graph collections keys. The new summary op is
added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
......
......@@ -5,8 +5,8 @@ github source.
## Requirements
The TensorFlow Python API currently requires Python 2.7. We are
[adding support for Python 3](https://github.com/tensorflow/tensorflow/issues/1).
The TensorFlow Python API currently supports Python 2.7 and Python 3.3+ from
source. We are preparing Python 3 pip packages to go with the 0.6.0 release.
The GPU version (Linux only) currently requires the Cuda Toolkit 7.0 and CUDNN
6.5 V2. Please see [Cuda installation](#install_cuda).
......
......@@ -6,6 +6,20 @@ answer on one of the TensorFlow [community resources](../resources/index.md).
[TOC]
## Features and Compatibility
#### Can I run distributed training on multiple computers?
The initial open-source release of TensorFlow supports multiple devices (CPUs
and GPUs) in a single computer. We are actively working on an open-source
multi-machine version, and plan to release it as soon as it's ready. You can
follow progress at the [GitHub issue](https://github.com/tensorflow/tensorflow/issues/23).
#### Does TensorFlow work with Python 3?
As of the 0.6.0 release timeframe (Early December 2015), we do support Python
3.3+.
## Building a TensorFlow graph
See also the
......@@ -107,12 +121,6 @@ The intermediate tensors that are created as part of a call to
[`Session.run()`](../api_docs/python/client.md) will be freed at or before the
end of the call.
#### Can I run distributed training on multiple computers?
The initial open-source release of TensorFlow supports multiple devices (CPUs
and GPUs) in a single computer. We are working on a distributed version as well:
if you are interested, please let us know so we can prioritize accordingly.
#### Does the runtime parallelize parts of graph execution?
The TensorFlow runtime parallelizes graph execution across many different
......@@ -295,12 +303,6 @@ for more details of how to define these different input types.
## Miscellaneous
#### Does TensorFlow work with Python 3?
We have only tested TensorFlow using Python 2.7. We are aware of some changes
that will be required for Python 3 compatibility, and welcome contributions
towards this effort.
#### What is TensorFlow's coding style convention?
The TensorFlow Python API adheres to the
......
......@@ -122,32 +122,46 @@ cc_library(
],
)
# What is needed for tf_gen_op_wrapper_py.
py_library(
name = "framework",
name = "framework_for_generated_wrappers",
srcs = [
# TODO(mrry): Move this to framework.
"client/graph_util.py",
"framework/device.py",
"framework/errors.py",
"framework/framework_lib.py",
"framework/importer.py",
"framework/dtypes.py",
"framework/op_def_registry.py",
"framework/ops.py",
"framework/random_seed.py",
"framework/registry.py",
"framework/tensor_shape.py",
"framework/dtypes.py",
"framework/tensor_util.py",
"framework/versions.py",
"ops/common_shapes.py",
# TODO(josh11b): Move this to the framework directory
"ops/op_def_library.py",
"ops/constant_op.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":platform",
":util",
"//tensorflow/core:protos_all_py",
],
)
py_library(
name = "framework",
srcs = [
# TODO(mrry): Move this to the framework directory.
"client/graph_util.py",
"framework/errors.py",
"framework/framework_lib.py",
"framework/importer.py",
"framework/random_seed.py",
"framework/tensor_util.py",
# TODO(josh11b): Move this to the framework directory
"ops/common_shapes.py",
],
srcs_version = "PY2AND3",
deps = [":framework_for_generated_wrappers"],
)
# subinclude("//third_party/py/cython:build_defs")
py_library(
......@@ -402,6 +416,8 @@ tf_gen_op_wrapper_py(
tf_gen_op_wrapper_py(
name = "image_ops",
hidden = [
"ResizeBilinearGrad",
"ResizeNearestNeighborGrad",
"ScaleImageGrad",
],
require_shape_functions = True,
......@@ -527,6 +543,9 @@ tf_gen_op_wrapper_py(
tf_gen_op_wrapper_py(
name = "sparse_ops",
hidden = [
"DeserializeManySparse",
"SerializeManySparse",
"SerializeSparse",
"SparseConcat",
"SparseSelectLastK",
"SparseReorder",
......@@ -574,7 +593,6 @@ py_library(
"ops/attention_ops.py",
"ops/candidate_sampling_ops.py",
"ops/clip_ops.py",
"ops/constant_op.py",
"ops/control_flow_grad.py",
"ops/control_flow_ops.py",
"ops/data_flow_grad.py",
......@@ -608,7 +626,6 @@ py_library(
"ops/nn_grad.py",
"ops/nn_ops.py",
"ops/numerics.py",
"ops/op_def_library.py",
"ops/parsing_ops.py",
"ops/random_ops.py",
"ops/rnn.py",
......
......@@ -90,6 +90,14 @@ class SessionTest(test_util.TensorFlowTestCase):
inp = constant_op.constant(10.0, name='W1')
self.assertAllEqual(inp.eval(), 10.0)
def testPerSessionThreads(self):
# TODO(keveman): Implement ListDevices and test for the number of
# devices returned by ListDevices.
with session.Session(
config=config_pb2.ConfigProto(use_per_session_threads=True)):
inp = constant_op.constant(10.0, name='W1')
self.assertAllEqual(inp.eval(), 10.0)
def testErrorsReported(self):
with session.Session() as s:
constant_op.constant(10.0, name='W1')
......
......@@ -80,7 +80,10 @@ def all_libraries(module_to_name, members, documented):
library("control_flow_ops", "Control Flow", prefix=PREFIX_TEXT),
library("image", "Images", tf.image, exclude_symbols=["ResizeMethod"],
prefix=PREFIX_TEXT),
library("sparse_ops", "Sparse Tensors", prefix=PREFIX_TEXT),
library("sparse_ops", "Sparse Tensors",
exclude_symbols=["serialize_sparse", "serialize_many_sparse",
"deserialize_many_sparse"],
prefix=PREFIX_TEXT),
library("io_ops", "Inputs and Readers",
exclude_symbols=["LookupTableBase", "HashTable",
"initialize_all_tables",
......
......@@ -3108,10 +3108,10 @@ class GraphKeys(object):
for more details.
"""
# Key to collect variables.Variable objects that must be saved and restored
# Key to collect Variable objects that must be saved and restored
# by the model.
VARIABLES = "variables"
# Key to collect variables.Variable objects that will be trained by the
# Key to collect Variable objects that will be trained by the
# optimizers.
TRAINABLE_VARIABLES = "trainable_variables"
# Key to collect summaries.
......
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SerializeSparse."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=g-bad-import-order,unused-import
import tensorflow.python.platform
import numpy as np
import tensorflow as tf
class SerializeSparseTest(tf.test.TestCase):
def _SparseTensorPlaceholder(self, dtype=None):
if dtype is None: dtype = tf.int32
return tf.SparseTensor(
tf.placeholder(tf.int64),
tf.placeholder(dtype),
tf.placeholder(tf.int64))
def _SparseTensorValue_5x6(self, permutation):
ind = np.array([
[0, 0],
[1, 0], [1, 3], [1, 4],
[3, 2], [3, 3]]).astype(np.int64)
val = np.array([0, 10, 13, 14, 32, 33]).astype(np.int32)
ind = ind[permutation]
val = val[permutation]
shape = np.array([5, 6]).astype(np.int64)
return tf.SparseTensorValue(ind, val, shape)
def _SparseTensorValue_3x4(self, permutation):
ind = np.array([
[0, 0],
[1, 0], [1, 2], [1, 3],
[2, 2], [2, 3]]).astype(np.int64)
val = np.array([0, 10, 13, 14, 32, 33]).astype(np.int32)
ind = ind[permutation]
val = val[permutation]
shape = np.array([3, 4]).astype(np.int64)
return tf.SparseTensorValue(ind, val, shape)
def _SparseTensorValue_1x1x1(self):
ind = np.array([[0, 0, 0]]).astype(np.int64)
val = np.array([0]).astype(np.int32)
shape = np.array([3, 4, 5]).astype(np.int64)
return tf.SparseTensorValue(ind, val, shape)
def testSerializeDeserializeMany(self):
with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorPlaceholder()
sp_input1 = self._SparseTensorPlaceholder()
input0_val = self._SparseTensorValue_5x6(np.arange(6))
input1_val = self._SparseTensorValue_3x4(np.arange(6))
serialized0 = tf.serialize_sparse(sp_input0)
serialized1 = tf.serialize_sparse(sp_input1)
serialized_concat = tf.pack([serialized0, serialized1])
sp_deserialized = tf.deserialize_many_sparse(
serialized_concat, dtype=tf.int32)
combined_indices, combined_values, combined_shape = sess.run(
sp_deserialized, {sp_input0: input0_val, sp_input1: input1_val})
self.assertAllEqual(combined_indices[:6, 0], [0] * 6) # minibatch 0
self.assertAllEqual(combined_indices[:6, 1:], input0_val[0])
self.assertAllEqual(combined_indices[6:, 0], [1] * 6) # minibatch 1
self.assertAllEqual(combined_indices[6:, 1:], input1_val[0])
self.assertAllEqual(combined_values[:6], input0_val[1])
self.assertAllEqual(combined_values[6:], input1_val[1])
self.assertAllEqual(combined_shape, [2, 5, 6])
def testSerializeManyDeserializeManyRoundTrip(self):
with self.test_session(use_gpu=False) as sess:
# N == 4 because shape_value == [4, 5]
indices_value = np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64)
values_value = np.array(["a", "b", "c"])
shape_value = np.array([4, 5], dtype=np.int64)
sparse_tensor = self._SparseTensorPlaceholder(dtype=tf.string)
serialized = tf.serialize_many_sparse(sparse_tensor)
deserialized = tf.deserialize_many_sparse(serialized, dtype=tf.string)
serialized_value, deserialized_value = sess.run(
[serialized, deserialized],
feed_dict={sparse_tensor.indices: indices_value,
sparse_tensor.values: values_value,
sparse_tensor.shape: shape_value})
self.assertEqual(serialized_value.shape, (4, 3))
self.assertAllEqual(deserialized_value.indices, indices_value)
self.assertAllEqual(deserialized_value.values, values_value)
self.assertAllEqual(deserialized_value.shape, shape_value)
def testDeserializeFailsWrongType(self):
with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorPlaceholder()
sp_input1 = self._SparseTensorPlaceholder()
input0_val = self._SparseTensorValue_5x6(np.arange(6))
input1_val = self._SparseTensorValue_3x4(np.arange(6))
serialized0 = tf.serialize_sparse(sp_input0)
serialized1 = tf.serialize_sparse(sp_input1)
serialized_concat = tf.pack([serialized0, serialized1])
sp_deserialized = tf.deserialize_many_sparse(
serialized_concat, dtype=tf.int64)
with self.assertRaisesOpError(
r"Requested SparseTensor of type int64 but "
r"SparseTensor\[0\].values.dtype\(\) == int32"):
sess.run(
sp_deserialized, {sp_input0: input0_val, sp_input1: input1_val})
def testDeserializeFailsInconsistentRank(self):
with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorPlaceholder()
sp_input1 = self._SparseTensorPlaceholder()
input0_val = self._SparseTensorValue_5x6(np.arange(6))
input1_val = self._SparseTensorValue_1x1x1()
serialized0 = tf.serialize_sparse(sp_input0)
serialized1 = tf.serialize_sparse(sp_input1)
serialized_concat = tf.pack([serialized0, serialized1])
sp_deserialized = tf.deserialize_many_sparse(
serialized_concat, dtype=tf.int32)
with self.assertRaisesOpError(
r"Inconsistent rank across SparseTensors: rank prior to "
r"SparseTensor\[1\] was: 3 but rank of SparseTensor\[1\] is: 4"):
sess.run(
sp_deserialized, {sp_input0: input0_val, sp_input1: input1_val})
def testDeserializeFailsInvalidProto(self):
with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorPlaceholder()
input0_val = self._SparseTensorValue_5x6(np.arange(6))
serialized0 = tf.serialize_sparse(sp_input0)
serialized1 = ["a", "b", "c"]
serialized_concat = tf.pack([serialized0, serialized1])
sp_deserialized = tf.deserialize_many_sparse(
serialized_concat, dtype=tf.int32)
with self.assertRaisesOpError(
r"Could not parse serialized_sparse\[1, 0\]"):
sess.run(sp_deserialized, {sp_input0: input0_val})
if __name__ == "__main__":
tf.test.main()
......@@ -34,6 +34,18 @@ class SummaryImageOpTest(tf.test.TestCase):
summ.ParseFromString(s)
return summ
def _CheckProto(self, image_summ, shape):
"""Verify that the non-image parts of the image_summ proto match shape."""
# Only the first 3 images are returned.
for v in image_summ.value:
v.image.ClearField("encoded_image_string")
expected = '\n'.join("""
value {
tag: "img/image/%d"
image { height: %d width: %d colorspace: %d }
}""" % ((i,) + shape[1:]) for i in xrange(3))
self.assertProtoEquals(expected, image_summ)
def testImageSummary(self):
np.random.seed(7)
with self.test_session() as sess:
......@@ -42,7 +54,7 @@ class SummaryImageOpTest(tf.test.TestCase):
bad_color = [255, 0, 0, 255][:depth]
for positive in False, True:
# Build a mostly random image with one nan
const = np.random.randn(*shape)
const = np.random.randn(*shape).astype(np.float32)
const[0, 1, 2] = 0 # Make the nan entry not the max
if positive:
const = 1 + np.maximum(const, 0)
......@@ -68,15 +80,33 @@ class SummaryImageOpTest(tf.test.TestCase):
self.assertAllClose(image, adjusted[0])
# Check the rest of the proto
# Only the first 3 images are returned.
for v in image_summ.value:
v.image.ClearField("encoded_image_string")
expected = '\n'.join("""
value {
tag: "img/image/%d"
image { height: %d width: %d colorspace: %d }
}""" % ((i,) + shape[1:]) for i in xrange(3))
self.assertProtoEquals(expected, image_summ)
self._CheckProto(image_summ, shape)
def testImageSummaryUint8(self):
np.random.seed(7)
with self.test_session() as sess:
for depth in 1, 3, 4:
shape = (4, 5, 7) + (depth,)
# Build a random uint8 image
images = np.random.randint(256, size=shape).astype(np.uint8)
tf_images = tf.convert_to_tensor(images)
self.assertEqual(tf_images.dtype, tf.uint8)
# Summarize
summ = tf.image_summary("img", tf_images)
value = sess.run(summ)
self.assertEqual([], summ.get_shape())
image_summ = self._AsSummary(value)
# Decode the first image and check consistency.
# Since we're uint8, everything should be exact.
image = image_ops.decode_png(
image_summ.value[0].image.encoded_image_string).eval()
self.assertAllEqual(image, images[0])
# Check the rest of the proto
self._CheckProto(image_summ, shape)
if __name__ == "__main__":
......
......@@ -97,6 +97,13 @@ class SummaryOpsTest(tf.test.TestCase):
self.assertEqual(summ2, merge.op.inputs[0])
self.assertTrue(tf.merge_all_summaries("bar_key") is None)
def testHistogramSummaryTypes(self):
with tf.Graph().as_default():
for dtype in (tf.int8, tf.uint8, tf.int16, tf.int32,
tf.float32, tf.float64):
const = tf.constant(10, dtype=dtype)
tf.histogram_summary("h", const, name="histo")
if __name__ == "__main__":
tf.test.main()
......@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains Gradient functions for image ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
......@@ -36,11 +36,33 @@ def _ResizeNearestNeighborGrad(op, grad):
Returns:
The gradients w.r.t. the input and the output.
"""
grads = gen_image_ops.resize_nearest_neighbor_grad(
# pylint: disable=protected-access
grads = gen_image_ops._resize_nearest_neighbor_grad(
grad, op.inputs[0].get_shape()[1:3])
# pylint: enable=protected-access
return [grads, None]
@ops.RegisterGradient("ResizeBilinear")
def _ResizeBilinearGrad(op, grad):
"""The derivatives for bilinear resizing.
Args:
op: The ResizeBilinear op.
grad: The tensor representing the gradient w.r.t. the output.
Returns:
The gradients w.r.t. the input.
"""
allowed_types = [dtypes.float32, dtypes.float64]
grad0 = None
if op.inputs[0].dtype in allowed_types:
# pylint: disable=protected-access
grad0 = gen_image_ops._resize_bilinear_grad(grad, op.inputs[0])
# pylint: enable=protected-access
return [grad0, None]
@ops.RegisterShape("ResizeNearestNeighborGrad")
def _ResizeShape(op):
"""Shape function for the resize grad ops."""
......@@ -55,3 +77,10 @@ def _ResizeShape(op):
return [
tensor_shape.TensorShape([input_shape[0], height, width, input_shape[3]])
]
@ops.RegisterShape("ResizeBilinearGrad")
def _ResizeBilinearGradShape(op):
"""Shape function for ResizeBilinearGrad."""
return [op.inputs[1].get_shape()]
......@@ -80,5 +80,68 @@ class ResizeNearestNeighborOpTest(tf.test.TestCase):
self.assertLess(err, 1e-3)
class ResizeBilinearOpTest(tf.test.TestCase):
def testShapeIsCorrectAfterOp(self):
in_shape = [1, 2, 2, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 4).reshape(in_shape).astype(np.float32)
with self.test_session() as sess:
input_tensor = tf.constant(x, shape=in_shape)
resize_out = tf.image.resize_bilinear(input_tensor,
out_shape[1:3])
self.assertEqual(out_shape, list(resize_out.get_shape()))
resize_out = sess.run(resize_out)
self.assertEqual(out_shape, list(resize_out.shape))
def testGradFromResizeToLargerInBothDims(self):
in_shape = [1, 2, 3, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
with self.test_session():
input_tensor = tf.constant(x, shape=in_shape)
resize_out = tf.image.resize_bilinear(input_tensor,
out_shape[1:3])
err = tf.test.compute_gradient_error(input_tensor,
in_shape,
resize_out,
out_shape,
x_init_value=x)
self.assertLess(err, 1e-3)
def testGradFromResizeToSmallerInBothDims(self):
in_shape = [1, 4, 6, 1]
out_shape = [1, 2, 3, 1]
x = np.arange(0, 24).reshape(in_shape).astype(np.float32)
with self.test_session():
input_tensor = tf.constant(x, shape=in_shape)
resize_out = tf.image.resize_bilinear(input_tensor,
out_shape[1:3])
err = tf.test.compute_gradient_error(input_tensor,
in_shape,
resize_out,
out_shape,
x_init_value=x)
self.assertLess(err, 1e-3)
def testGradOnUnsupportedType(self):
in_shape = [1, 4, 6, 1]
out_shape = [1, 2, 3, 1]
x = np.arange(0, 24).reshape(in_shape).astype(np.uint8)
with self.test_session():
input_tensor = tf.constant(x, shape=in_shape)
resize_out = tf.image.resize_bilinear(input_tensor, out_shape[1:3])
grad = tf.gradients(input_tensor, [resize_out])
self.assertEqual([None], grad)
if __name__ == "__main__":
tf.test.main()
......@@ -85,25 +85,67 @@ resized_image = tf.image.resize_images(image, 299, 299)
## Converting Between Colorspaces.
Image ops work either on individual images or on batches of images, depending on
the shape of their input Tensor.
If 3-D, the shape is `[height, width, channels]`, and the Tensor represents one
image. If 4-D, the shape is `[batch_size, height, width, channels]`, and the
Tensor represents `batch_size` images.
Currently, `channels` can usefully be 1, 2, 3, or 4. Single-channel images are
grayscale, images with 3 channels are encoded as either RGB or HSV. Images
with 2 or 4 channels include an alpha channel, which has to be stripped from the
image before passing the image to most image processing functions (and can be
re-attached later).
Internally, images are either stored in as one `float32` per channel per pixel
(implicitly, values are assumed to lie in `[0,1)`) or one `uint8` per channel
per pixel (values are assumed to lie in `[0,255]`).
Tensorflow can convert between images in RGB or HSV. The conversion functions
work only on float images, so you need to convert images in other formats using
[`convert_image_dtype`](#convert-image-dtype).
Example:
```python
# Decode an image and convert it to HSV.
rgb_image = tf.decode_png(..., channels=3)
rgb_image_float = tf.convert_image_dtype(rgb_image, tf.float32)
hsv_image = tf.hsv_to_rgb(rgb_image)
```
@@rgb_to_grayscale
@@grayscale_to_rgb
@@hsv_to_rgb
@@rgb_to_hsv
@@convert_image_dtype
## Image Adjustments
TensorFlow provides functions to adjust images in various ways: brightness,
contrast, hue, and saturation. Each adjustment can be done with predefined
parameters or with random parameters picked from predefined intervals. Random
parameters or with random parameters picked from predefined intervals. Random
adjustments are often useful to expand a training set and reduce overfitting.
If several adjustments are chained it is advisable to minimize the number of
redundant conversions by first converting the images to the most natural data
type and representation (RGB or HSV).
@@adjust_brightness
@@random_brightness
@@adjust_contrast
@@random_contrast
@@adjust_hue
@@random_hue
@@adjust_saturation
@@random_saturation
@@per_image_whitening
"""
from __future__ import absolute_import
......@@ -133,8 +175,9 @@ from tensorflow.python.ops.gen_image_ops import *
from tensorflow.python.ops.gen_attention_ops import *
# pylint: enable=wildcard-import
ops.NoGradient('ResizeBilinear')
ops.NoGradient('RandomCrop')
ops.NoGradient('RGBToHSV')
ops.NoGradient('HSVToRGB')
def _ImageDimensions(images):
......@@ -875,3 +918,215 @@ def convert_image_dtype(image, dtype, name=None):
scale = dtype.max + 0.5 # avoid rounding problems in the cast
scaled = math_ops.mul(image, scale)
return math_ops.cast(scaled, dtype)
def rgb_to_grayscale(images):
"""Converts one or more images from RGB to Grayscale.
Outputs a tensor of the same `DType` and rank as `images`. The size of the
last dimension of the output is 1, containing the Grayscale value of the
pixels.
Args:
images: The RGB tensor to convert. Last dimension must have size 3 and
should contain RGB values.
Returns:
The converted grayscale image(s).
"""
with ops.op_scope([images], None, 'rgb_to_grayscale'):
# Remember original dtype to so we can convert back if needed
orig_dtype = images.dtype
flt_image = convert_image_dtype(images, dtypes.float32)
# Reference for converting between RGB and grayscale.
# https://en.wikipedia.org/wiki/Luma_%28video%29
rgb_weights = [0.2989, 0.5870, 0.1140]
rank_1 = array_ops.expand_dims(array_ops.rank(images) - 1, 0)
gray_float = math_ops.reduce_sum(flt_image * rgb_weights,
rank_1,
keep_dims=True)
return convert_image_dtype(gray_float, orig_dtype)
def grayscale_to_rgb(images):
"""Converts one or more images from Grayscale to RGB.
Outputs a tensor of the same `DType` and rank as `images`. The size of the
last dimension of the output is 3, containing the RGB value of the pixels.
Args:
images: The Grayscale tensor to convert. Last dimension must be size 1.
Returns:
The converted grayscale image(s).
"""
with ops.op_scope([images], None, 'grayscale_to_rgb'):
rank_1 = array_ops.expand_dims(array_ops.rank(images) - 1, 0)
shape_list = (
[array_ops.ones(rank_1,
dtype=dtypes.int32)] + [array_ops.expand_dims(3, 0)])
multiples = array_ops.concat(0, shape_list)
return array_ops.tile(images, multiples)
# pylint: disable=invalid-name
@ops.RegisterShape('HSVToRGB')
@ops.RegisterShape('RGBToHSV')
def _ColorspaceShape(op):
"""Shape function for colorspace ops."""
input_shape = op.inputs[0].get_shape().with_rank_at_least(1)
input_rank = input_shape.ndims
if input_rank is not None:
input_shape = input_shape.merge_with([None] * (input_rank - 1) + [3])
return [input_shape]
# pylint: enable=invalid-name
def random_hue(image, max_delta, seed=None):
"""Adjust the hue of an RGB image by a random factor.
Equivalent to `adjust_hue()` but uses a `delta` randomly
picked in the interval `[-max_delta, max_delta]`.
`max_delta` must be in the interval `[0, 0.5]`.
Args:
image: RGB image or images. Size of the last dimension must be 3.
max_delta: float. Maximum value for the random delta.
seed: An operation-specific seed. It will be used in conjunction
with the graph-level seed to determine the real seeds that will be
used in this operation. Please see the documentation of
set_random_seed for its interaction with the graph-level random seed.
Returns:
3-D float tensor of shape `[height, width, channels]`.
Raises:
ValueError: if `max_delta` is invalid.
"""
if max_delta > 0.5:
raise ValueError('max_delta must be <= 0.5.')
if max_delta < 0:
raise ValueError('max_delta must be non-negative.')
delta = random_ops.random_uniform([], -max_delta, max_delta, seed=seed)
return adjust_hue(image, delta)
def adjust_hue(image, delta, name=None):
"""Adjust hue of an RGB image.
This is a convenience method that converts an RGB image to float
representation, converts it to HSV, add an offset to the hue channel, converts
back to RGB and then back to the original data type. If several adjustments
are chained it is advisable to minimize the number of redundant conversions.
`image` is an RGB image. The image hue is adjusted by converting the
image to HSV and rotating the hue channel (H) by
`delta`. The image is then converted back to RGB.
`delta` must be in the interval `[-1, 1]`.
Args:
image: RGB image or images. Size of the last dimension must be 3.
delta: float. How much to add to the hue channel.
name: A name for this operation (optional).
Returns:
Adjusted image(s), same shape and DType as `image`.
"""
with ops.op_scope([image], name, 'adjust_hue') as name:
# Remember original dtype to so we can convert back if needed
orig_dtype = image.dtype
flt_image = convert_image_dtype(image, dtypes.float32)
hsv = gen_image_ops.rgb_to_hsv(flt_image)
hue = array_ops.slice(hsv, [0, 0, 0], [-1, -1, 1])
saturation = array_ops.slice(hsv, [0, 0, 1], [-1, -1, 1])
value = array_ops.slice(hsv, [0, 0, 2], [-1, -1, 1])
# Note that we add 2*pi to guarantee that the resulting hue is a positive
# floating point number since delta is [-0.5, 0.5].
hue = math_ops.mod(hue + (delta + 1.), 1.)
hsv_altered = array_ops.concat(2, [hue, saturation, value])
rgb_altered = gen_image_ops.hsv_to_rgb(hsv_altered)
return convert_image_dtype(rgb_altered, orig_dtype)
def random_saturation(image, lower, upper, seed=None):
"""Adjust the saturation of an RGB image by a random factor.
Equivalent to `adjust_saturation()` but uses a `saturation_factor` randomly
picked in the interval `[lower, upper]`.
Args:
image: RGB image or images. Size of the last dimension must be 3.
lower: float. Lower bound for the random saturation factor.
upper: float. Upper bound for the random saturation factor.
seed: An operation-specific seed. It will be used in conjunction
with the graph-level seed to determine the real seeds that will be
used in this operation. Please see the documentation of
set_random_seed for its interaction with the graph-level random seed.
Returns:
Adjusted image(s), same shape and DType as `image`.
Raises:
ValueError: if `upper <= lower` or if `lower < 0`.
"""
if upper <= lower:
raise ValueError('upper must be > lower.')
if lower < 0:
raise ValueError('lower must be non-negative.')
# Pick a float in [lower, upper]
saturation_factor = random_ops.random_uniform([], lower, upper, seed=seed)
return adjust_saturation(image, saturation_factor)
def adjust_saturation(image, saturation_factor, name=None):
"""Adjust staturation of an RGB image.
This is a convenience method that converts an RGB image to float
representation, converts it to HSV, add an offset to the saturation channel,
converts back to RGB and then back to the original data type. If several
adjustments are chained it is advisable to minimize the number of redundant
conversions.
`image` is an RGB image. The image saturation is adjusted by converting the
image to HSV and multiplying the saturation (S) channel by
`saturation_factor` and clipping. The image is then converted back to RGB.
Args:
image: RGB image or images. Size of the last dimension must be 3.
saturation_factor: float. Factor to multiply the saturation by.
name: A name for this operation (optional).
Returns:
Adjusted image(s), same shape and DType as `image`.
"""
with ops.op_scope([image], name, 'adjust_saturation') as name:
# Remember original dtype to so we can convert back if needed
orig_dtype = image.dtype
flt_image = convert_image_dtype(image, dtypes.float32)
hsv = gen_image_ops.rgb_to_hsv(flt_image)
hue = array_ops.slice(hsv, [0, 0, 0], [-1, -1, 1])
saturation = array_ops.slice(hsv, [0, 0, 1], [-1, -1, 1])
value = array_ops.slice(hsv, [0, 0, 2], [-1, -1, 1])
saturation *= saturation_factor
saturation = clip_ops.clip_by_value(saturation, 0.0, 1.0)
hsv_altered = array_ops.concat(2, [hue, saturation, value])
rgb_altered = gen_image_ops.hsv_to_rgb(hsv_altered)
return convert_image_dtype(rgb_altered, orig_dtype)
......@@ -27,12 +27,179 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import test_util
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import googletest
class RGBToHSVTest(test_util.TensorFlowTestCase):
def testBatch(self):
# Build an arbitrary RGB image
np.random.seed(7)
batch_size = 5
shape = (batch_size, 2, 7, 3)
inp = np.random.rand(*shape).astype(np.float32)
# Convert to HSV and back, as a batch and individually
with self.test_session() as sess:
batch0 = constant_op.constant(inp)
batch1 = image_ops.rgb_to_hsv(batch0)
batch2 = image_ops.hsv_to_rgb(batch1)
split0 = array_ops.unpack(batch0)
split1 = map(image_ops.rgb_to_hsv, split0)
split2 = map(image_ops.hsv_to_rgb, split1)
join1 = array_ops.pack(split1)
join2 = array_ops.pack(split2)
batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2])
# Verify that processing batch elements together is the same as separate
self.assertAllClose(batch1, join1)
self.assertAllClose(batch2, join2)
self.assertAllClose(batch2, inp)
def testRGBToHSVRoundTrip(self):
data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
rgb_np = np.array(data, dtype=np.float32).reshape([2, 2, 3]) / 255.
for use_gpu in [True, False]:
with self.test_session(use_gpu=use_gpu):
hsv = image_ops.rgb_to_hsv(rgb_np)
rgb = image_ops.hsv_to_rgb(hsv)
rgb_tf = rgb.eval()
self.assertAllClose(rgb_tf, rgb_np)
class GrayscaleToRGBTest(test_util.TensorFlowTestCase):
def _RGBToGrayscale(self, images):
is_batch = True
if len(images.shape) == 3:
is_batch = False
images = np.expand_dims(images, axis=0)
out_shape = images.shape[0:3] + (1,)
out = np.zeros(shape=out_shape, dtype=np.uint8)
for batch in xrange(images.shape[0]):
for y in xrange(images.shape[1]):
for x in xrange(images.shape[2]):
red = images[batch, y, x, 0]
green = images[batch, y, x, 1]
blue = images[batch, y, x, 2]
gray = 0.2989 * red + 0.5870 * green + 0.1140 * blue
out[batch, y, x, 0] = int(gray)
if not is_batch:
out = np.squeeze(out, axis=0)
return out
def _TestRGBToGrayscale(self, x_np):
y_np = self._RGBToGrayscale(x_np)
with self.test_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.rgb_to_grayscale(x_tf)
y_tf = y.eval()
self.assertAllEqual(y_tf, y_np)
def testBasicRGBToGrayscale(self):
# 4-D input with batch dimension.
x_np = np.array([[1, 2, 3], [4, 10, 1]],
dtype=np.uint8).reshape([1, 1, 2, 3])
self._TestRGBToGrayscale(x_np)
# 3-D input with no batch dimension.
x_np = np.array([[1, 2, 3], [4, 10, 1]], dtype=np.uint8).reshape([1, 2, 3])
self._TestRGBToGrayscale(x_np)
def testBasicGrayscaleToRGB(self):
# 4-D input with batch dimension.
x_np = np.array([[1, 2]], dtype=np.uint8).reshape([1, 1, 2, 1])
y_np = np.array([[1, 1, 1], [2, 2, 2]],
dtype=np.uint8).reshape([1, 1, 2, 3])
with self.test_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.grayscale_to_rgb(x_tf)
y_tf = y.eval()
self.assertAllEqual(y_tf, y_np)
# 3-D input with no batch dimension.
x_np = np.array([[1, 2]], dtype=np.uint8).reshape([1, 2, 1])
y_np = np.array([[1, 1, 1], [2, 2, 2]], dtype=np.uint8).reshape([1, 2, 3])
with self.test_session():
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.grayscale_to_rgb(x_tf)
y_tf = y.eval()
self.assertAllEqual(y_tf, y_np)
class AdjustHueTest(test_util.TensorFlowTestCase):
def testAdjustNegativeHue(self):
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
delta = -0.25
y_data = [0, 13, 1, 54, 226, 59, 8, 234, 150, 255, 39, 1]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
with self.test_session():
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.adjust_hue(x, delta)
y_tf = y.eval()
self.assertAllEqual(y_tf, y_np)
def testAdjustPositiveHue(self):
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
delta = 0.25
y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
with self.test_session():
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.adjust_hue(x, delta)
y_tf = y.eval()
self.assertAllEqual(y_tf, y_np)
class AdjustSaturationTest(test_util.TensorFlowTestCase):
def testHalfSaturation(self):
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
saturation_factor = 0.5
y_data = [6, 9, 13, 140, 180, 226, 135, 121, 234, 172, 255, 128]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
with self.test_session():
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.adjust_saturation(x, saturation_factor)
y_tf = y.eval()
self.assertAllEqual(y_tf, y_np)
def testTwiceSaturation(self):
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
saturation_factor = 2.0
y_data = [0, 5, 13, 0, 106, 226, 30, 0, 234, 89, 255, 0]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
with self.test_session():
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.adjust_saturation(x, saturation_factor)
y_tf = y.eval()
self.assertAllEqual(y_tf, y_np)
class FlipTest(test_util.TensorFlowTestCase):
def testIdempotentLeftRight(self):
......
......@@ -513,3 +513,156 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None):
sp_ordered_output = sparse_reorder(sp_unordered_output)
return sp_ordered_output, empty_row_indicator
def serialize_sparse(sp_input, name=None):
"""Serialize a `SparseTensor` into a string 3-vector (1-D `Tensor`) object.
Args:
sp_input: The input `SparseTensor`.
name: A name prefix for the returned tensors (optional).
Returns:
A string 3-vector (1D `Tensor`), with each column representing the
serialized `SparseTensor`'s indices, values, and shape (respectively).
Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
"""
if not isinstance(sp_input, ops.SparseTensor):
raise TypeError("Input must be a SparseTensor.")
return gen_sparse_ops._serialize_sparse(
sp_input.indices,
sp_input.values,
sp_input.shape,
name=name)
@ops.RegisterShape("SerializeSparse")
def _SerializeSparseShape(op): # pylint: disable=invalid-name
"""Shape function for SerializeSparse op."""
op.inputs[0].get_shape().with_rank(2)
op.inputs[1].get_shape().with_rank(1)
op.inputs[2].get_shape().with_rank(1)
return [tensor_shape.vector(3)]
def serialize_many_sparse(sp_input, name=None):
"""Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` string `Tensor`.
The `SparseTensor` must have rank `R` greater than 1, and the first dimension
is treated as the minibatch dimension. Elements of the `SparseTensor`
must be sorted in increasing order of this first dimension. The serialized
`SparseTensor` objects going into each row of the output `Tensor` will have
rank `R-1`.
The minibatch size `N` is extracted from `sparse_shape[0]`.
Args:
sp_input: The input rank `R` `SparseTensor`.
name: A name prefix for the returned tensors (optional).
Returns:
A string matrix (2-D `Tensor`) with `N` rows and `3` columns.
Each column represents serialized `SparseTensor`'s indices, values, and
shape (respectively).
Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
"""
if not isinstance(sp_input, ops.SparseTensor):
raise TypeError("Input must be a SparseTensor.")
return gen_sparse_ops._serialize_many_sparse(
sp_input.indices,
sp_input.values,
sp_input.shape,
name=name)
@ops.RegisterShape("SerializeManySparse")
def _SerializeManySparseShape(op): # pylint: disable=invalid-name
"""Shape function for SerializeSparse op."""
op.inputs[0].get_shape().with_rank(2)
op.inputs[1].get_shape().with_rank(1)
op.inputs[2].get_shape().with_rank(1)
return [tensor_shape.matrix(None, 3)]
def deserialize_many_sparse(serialized_sparse, dtype, name=None):
"""Deserialize and concatenate `SparseTensors` from a serialized minibatch.
The input `serialized_sparse` must be a string matrix of shape `[N x 3]` where
`N` is the minibatch size and the rows correspond to packed outputs of
`serialize_sparse`. The ranks of the original `SparseTensor` objects
must all match. When the final `SparseTensor` is created, it has rank one
higher than the ranks of the incoming `SparseTensor` objects (they have been
concatenated along a new row dimension).
The output `SparseTensor` object's shape values for all dimensions but the
first are the max across the input `SparseTensor` objects' shape values
for the corresponding dimensions. Its first shape value is `N`, the minibatch
size.
The input `SparseTensor` objects' indices are assumed ordered in
standard lexicographic order. If this is not the case, after this
step run `sparse_reorder` to restore index ordering.
For example, if the serialized input is a `[2, 3]` matrix representing two
original `SparseTensor` objects:
index = [ 0]
[10]
[20]
values = [1, 2, 3]
shape = [50]
and
index = [ 2]
[10]
values = [4, 5]
shape = [30]
then the final deserialized `SparseTensor` will be:
index = [0 0]
[0 10]
[0 20]
[1 2]
[1 10]
values = [1, 2, 3, 4, 5]
shape = [2 50]
Args:
serialized_sparse: 2-D `Tensor` of type `string` of shape `[N, 3]`.
The serialized and packed `SparseTensor' objects.
dtype: The `dtype` of the serialized `SparseTensor` objects.
name: A name prefix for the returned tensors (optional)
Returns:
A `SparseTensor` representing the deserialized `SparseTensor`s,
concatenated along the `SparseTensor`s' first dimension.
All of the serialized `SparseTensor`s must have had the same rank and type.
"""
output_indices, output_values, output_shape = (
gen_sparse_ops._deserialize_many_sparse(
serialized_sparse, dtype, name=name))
return ops.SparseTensor(output_indices, output_values, output_shape)
@ops.RegisterShape("DeserializeManySparse")
def _DeserializeSparseShape(op): # pylint: disable=invalid-name
"""Shape function for DeserializeManySparse op."""
serialized_sparse_shape = op.inputs[0].get_shape().with_rank(2)
serialized_sparse_shape.merge_with(
tensor_shape.TensorShape([None, 3]))
return [tensor_shape.matrix(None, None),
tensor_shape.vector(None),
tensor_shape.vector(None)]
......@@ -43,7 +43,7 @@ def histogram_summary(tag, values, collections=None, name=None):
Args:
tag: A `string` `Tensor`. 0-D. Tag to use for the summary value.
values: A `float32` or `float64` `Tensor`. Any shape. Values to use to
values: A real numeric `Tensor`. Any shape. Values to use to
build the histogram.
collections: Optional list of graph collections keys. The new summary op is
added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
......@@ -71,9 +71,10 @@ def image_summary(tag, tensor, max_images=None, collections=None, name=None):
* 3: `tensor` is interpreted as RGB.
* 4: `tensor` is interpreted as RGBA.
The images have the same number of channels as the input tensor. Their values
are normalized, one image at a time, to fit in the range `[0, 255]`. The
op uses two different normalization algorithms:
The images have the same number of channels as the input tensor. For float
input, the values are normalized one image at a time to fit in the range
`[0, 255]`. `uint8` values are unchanged. The op uses two different
normalization algorithms:
* If the input values are all positive, they are rescaled so the largest one
is 255.
......@@ -92,8 +93,8 @@ def image_summary(tag, tensor, max_images=None, collections=None, name=None):
Args:
tag: A scalar `Tensor` of type `string`. Used to build the `tag`
of the summary values.
tensor: A 4-D `float32` `Tensor` of shape `[batch_size, height, width,
channels]` where `channels` is 1, 3, or 4.
tensor: A 4-D `uint8` or `float32` `Tensor` of shape `[batch_size, height,
width, channels]` where `channels` is 1, 3, or 4.
max_images: Max number of batch elements to generate images for.
collections: Optional list of ops.GraphKeys. The collections to add the
summary to. Defaults to [ops.GraphKeys.SUMMARIES]
......
......@@ -69,9 +69,9 @@ class _VariableStore(object):
initializer: initializer for the variable.
reuse: a Boolean or `None`. Controls reuse or creation of variables.
trainable: If `True` also add the variable to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see variables.Variable).
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
collections: List of graph collections keys to add the Variable to.
Defaults to `[GraphKeys.VARIABLES]` (see variables.Variable).
Defaults to `[GraphKeys.VARIABLES]` (see tf.Variable).
Returns:
The created or existing variable.
......@@ -225,9 +225,9 @@ def get_variable(name, shape=None, dtype=dtypes.float32, initializer=None,
dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
initializer: initializer for the variable if one is created.
trainable: If `True` also add the variable to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see variables.Variable).
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
collections: List of graph collections keys to add the Variable to.
Defaults to `[GraphKeys.VARIABLES]` (see variables.Variable).
Defaults to `[GraphKeys.VARIABLES]` (see tf.Variable).
Returns:
The created or existing variable.
......
......@@ -42,7 +42,7 @@ class Optimizer(object):
# Create an optimizer with the desired parameters.
opt = GradientDescentOptimizer(learning_rate=0.1)
# Add Ops to the graph to minimize a cost by updating a list of variables.
# "cost" is a Tensor, and the list of variables contains variables.Variable
# "cost" is a Tensor, and the list of variables contains tf.Variable
# objects.
opt_op = opt.minimize(cost, <list of variables>)
```
......@@ -199,7 +199,7 @@ class Optimizer(object):
Args:
loss: A Tensor containing the value to minimize.
var_list: Optional list of variables.Variable to update to minimize
var_list: Optional list of tf.Variable to update to minimize
`loss`. Defaults to the list of variables collected in the graph
under the key `GraphKey.TRAINABLE_VARIABLES`.
gate_gradients: How to gate the computation of gradients. Can be
......@@ -224,7 +224,7 @@ class Optimizer(object):
var_list = variables.trainable_variables()
for var in var_list:
if not isinstance(var, variables.Variable):
raise TypeError("Argument is not a variables.Variable: %s" % var)
raise TypeError("Argument is not a tf.Variable: %s" % var)
if not var_list:
raise ValueError("No variables to optimize")
grads = gradients.gradients(
......@@ -268,13 +268,13 @@ class Optimizer(object):
"Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
if not isinstance(v, variables.Variable):
raise TypeError(
"Variable must be a variables.Variable: %s" % v)
"Variable must be a tf.Variable: %s" % v)
if g is not None:
self._assert_valid_dtypes([g, v])
var_list = [v for g, v in grads_and_vars if g is not None]
if not var_list:
raise ValueError("No gradients provided for any variable: %s" %
grads_and_vars)
(grads_and_vars,))
self._create_slots(var_list)
update_ops = []
with ops.op_scope([], name, self._name) as name:
......@@ -339,7 +339,7 @@ class Optimizer(object):
dtype = t.dtype.base_dtype
if dtype not in valid_dtypes:
raise ValueError(
"Invalid type %s for %s, expected: %s." % (
"Invalid type %r for %s, expected: %s." % (
dtype, t.name, [v for v in valid_dtypes]))
# --------------
......
......@@ -1140,7 +1140,7 @@ class BlasSupport {
// Macro used to quickly declare overrides for abstract virtuals in the
// BlasSupport base class.
#define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \
#define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \
bool DoBlasAsum(Stream *stream, uint64 elem_count, \
const DeviceMemory<float> &x, int incx, \
DeviceMemory<float> *result) override; \
......
......@@ -22,7 +22,8 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
#include "tensorflow/stream_executor/cuda/cuda_platform.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/dso_loader.h"
#include "tensorflow/stream_executor/lib/initialize.h"
......
......@@ -29,16 +29,17 @@ limitations under the License.
#include <memory>
#include <vector>
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/inlined_vector.h"
#include "tensorflow/stream_executor/lib/numbers.h"
#include "tensorflow/stream_executor/lib/process_state.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/str_util.h"
#include "tensorflow/stream_executor/lib/strcat.h"
#include "tensorflow/stream_executor/lib/stringpiece.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/lib/numbers.h"
#include "tensorflow/stream_executor/lib/str_util.h"
#include "tensorflow/stream_executor/lib/inlined_vector.h"
namespace perftools {
namespace gputools {
......@@ -113,7 +114,6 @@ void Diagnostician::LogDiagnosticInformation() {
LOG(INFO) << "retrieving CUDA diagnostic information for host: "
<< port::Hostname();
LogDriverVersionInformation();
}
......
......@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DIAGNOSTICS_H_
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DIAGNOSTICS_H_
#include "tensorflow/stream_executor/platform/port.h"
#include <tuple>
#include "tensorflow/stream_executor/lib/statusor.h"
......
......@@ -50,7 +50,8 @@ class CudnnSupport : public dnn::DnnSupport {
const DeviceMemory<float>& filter_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::BatchDescriptor& output_descriptor,
DeviceMemory<float>* output_data) override;
DeviceMemory<float>* output_data,
ScratchAllocator* scratch_allocator) override;
bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
const DeviceMemory<double>& input_data,
......@@ -80,7 +81,8 @@ class CudnnSupport : public dnn::DnnSupport {
DeviceMemory<float> backward_output_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::BatchDescriptor& input_descriptor,
DeviceMemory<float>* backward_input_data) override;
DeviceMemory<float>* backward_input_data,
ScratchAllocator* scratch_allocator) override;
bool DoConvolveBackwardFilter(
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
......@@ -89,7 +91,8 @@ class CudnnSupport : public dnn::DnnSupport {
DeviceMemory<float> backward_output_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::FilterDescriptor& filter_descriptor,
DeviceMemory<float>* backward_filter_data) override;
DeviceMemory<float>* backward_filter_data,
ScratchAllocator* scratch_allocator) override;
bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
const DeviceMemory<float>& weights,
......@@ -160,20 +163,24 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<float>* output_data) override;
bool DoMemcpyD2HQuantized(Stream* stream,
const DeviceMemory<float>& device_unquantized_src,
port::MutableArraySlice<uint8> host_dst) override;
bool DoXYPad(Stream* stream, const dnn::BatchDescriptor &dimensions,
const DeviceMemory<float> &input_data,
int64 left_pad, int64 right_pad, int64 top_pad,
int64 bottom_pad, DeviceMemory<float> *output_data) override;
bool DoMemcpyD2HQuantized(Stream* stream,
const DeviceMemory<float>& device_unquantized_src,
port::MutableArraySlice<uint16> host_dst) override;
bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor &dimensions,
const DeviceMemory<float> &input_data,
int64 left_trim, int64 right_trim, int64 top_trim,
int64 bottom_trim, DeviceMemory<float> *output_data) override;
bool DoMemcpyD2HQuantized(Stream* stream,
const DeviceMemory<float>& device_unquantized_src,
port::MutableArraySlice<int32> host_dst) override;
dnn::QuantizedActivationMode mode, void* host_dst,
int64 size) override;
bool DoMemcpyH2DQuantized(
Stream* stream, port::ArraySlice<uint8> host_src,
Stream* stream, const void* host_src, int64 size,
dnn::QuantizedActivationMode mode,
DeviceMemory<float>* device_unquantized_dst) override;
// Derives an output batch descriptor from an input batch and convolution
......
......@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_event.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/lib/statusor.h"
......
......@@ -22,7 +22,8 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
#include "tensorflow/stream_executor/cuda/cuda_platform.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/dso_loader.h"
#include "tensorflow/stream_executor/lib/initialize.h"
......
......@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
#include "tensorflow/stream_executor/cuda/cuda_event.h"
#include "tensorflow/stream_executor/cuda/cuda_platform.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/cuda/cuda_timer.h"
#include "tensorflow/stream_executor/dso_loader.h"
......@@ -88,20 +88,6 @@ static CUDAEvent *AsCUDAEvent(Event *event) {
return static_cast<CUDAEvent *>(event->implementation());
}
// Given a platform-independent stream datatype, returns the internal CUDA
// platform implementation pointer.
static CUDAStream *AsCUDAStream(Stream *stream) {
DCHECK(stream != nullptr);
return static_cast<CUDAStream *>(stream->implementation());
}
// Given a platform-independent stream datatype, returns the platform
// implementation's internal value, suitable for passing directly to libcuda
// APIs.
CUstream AsCUDAStreamValue(Stream *stream) {
DCHECK(stream != nullptr);
return AsCUDAStream(stream)->cuda_stream();
}
// Given a platform-independent timer datatype, returns the internal CUDA
// platform implementation pointer.
......@@ -861,6 +847,26 @@ bool CUDAExecutor::SupportsFft() const { return true; }
bool CUDAExecutor::SupportsRng() const { return true; }
std::unique_ptr<internal::EventInterface>
CUDAExecutor::CreateEventImplementation() {
return std::unique_ptr<internal::EventInterface>(new CUDAEvent(this));
}
std::unique_ptr<internal::KernelInterface>
CUDAExecutor::CreateKernelImplementation() {
return std::unique_ptr<internal::KernelInterface>(new CUDAKernel());
}
std::unique_ptr<internal::StreamInterface>
CUDAExecutor::GetStreamImplementation() {
return std::unique_ptr<internal::StreamInterface>(new CUDAStream(this));
}
std::unique_ptr<internal::TimerInterface>
CUDAExecutor::GetTimerImplementation() {
return std::unique_ptr<internal::TimerInterface>(new CUDATimer(this));
}
void *CUDAExecutor::CudaContextHack() { return context_; }
CUcontext CUDAExecutor::cuda_context() { return context_; }
......@@ -1064,30 +1070,6 @@ void initialize_cuda_gpu_executor() {
const gpu::PluginConfig &config) {
return new gpu::cuda::CUDAExecutor{config};
};
*gpu::internal::MakeCUDAKernelImplementation() = []() {
return new gpu::cuda::CUDAKernel;
};
*gpu::internal::MakeCUDAEventImplementation() = [](
gpu::StreamExecutor *parent) {
gpu::cuda::CUDAExecutor *cuda_executor =
static_cast<gpu::cuda::CUDAExecutor *>(parent->implementation());
return new gpu::cuda::CUDAEvent{cuda_executor};
};
*gpu::internal::MakeCUDAStreamImplementation() = [](
gpu::StreamExecutor *parent) {
gpu::cuda::CUDAExecutor *cuda_executor =
static_cast<gpu::cuda::CUDAExecutor *>(parent->implementation());
return new gpu::cuda::CUDAStream{cuda_executor};
};
*gpu::internal::MakeCUDATimerImplementation() = [](
gpu::StreamExecutor *parent) {
gpu::cuda::CUDAExecutor *cuda_executor =
static_cast<gpu::cuda::CUDAExecutor *>(parent->implementation());
return new gpu::cuda::CUDATimer{cuda_executor};
};
}
} // namespace gputools
......
......@@ -203,6 +203,16 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
dnn::DnnSupport *CreateDnn() override;
std::unique_ptr<internal::EventInterface> CreateEventImplementation()
override;
std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
override;
std::unique_ptr<internal::StreamInterface> GetStreamImplementation() override;
std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override;
void *CudaContextHack() override;
CUcontext cuda_context();
......
......@@ -30,7 +30,6 @@ limitations under the License.
namespace perftools {
namespace gputools {
class Stream;
template <typename ElemT>
class DeviceMemory;
......@@ -51,8 +50,6 @@ T *CUDAMemoryMutable(DeviceMemory<T> *mem) {
return static_cast<T *>(mem->opaque());
}
CUstream AsCUDAStreamValue(Stream *stream);
static_assert(sizeof(std::complex<float>) == sizeof(cuComplex),
"std::complex<float> and cuComplex should have the same size");
static_assert(offsetof(cuComplex, x) == 0,
......
......@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_platform.h"
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/ptr_util.h"
......@@ -26,8 +28,6 @@ namespace perftools {
namespace gputools {
namespace cuda {
PLATFORM_DEFINE_ID(kCudaPlatformId);
CudaPlatform::CudaPlatform()
: name_("CUDA"), min_numa_node_(0), limit_numa_node_(0) {}
......@@ -147,8 +147,8 @@ port::StatusOr<StreamExecutor*> CudaPlatform::GetExecutor(
port::StatusOr<std::unique_ptr<StreamExecutor>>
CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
auto executor = port::MakeUnique<StreamExecutor>(PlatformKind::kCuda,
config.plugin_config);
auto executor = port::MakeUnique<StreamExecutor>(
this, new CUDAExecutor(config.plugin_config));
auto init_status = executor->Init(config.ordinal, config.device_options);
if (!init_status.ok()) {
return port::Status{
......
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
namespace perftools {
namespace gputools {
namespace cuda {
PLATFORM_DEFINE_ID(kCudaPlatformId);
} // namespace cuda
} // namespace gputools
} // namespace perftools
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_ID_H_
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_ID_H_
#include "tensorflow/stream_executor/platform.h"
namespace perftools {
namespace gputools {
namespace cuda {
// Opaque and unique identifier for the cuda platform.
// This is needed so that plugins can refer to/identify this platform without
// instantiating a CudaPlatform object.
// This is broken out here to avoid a circular dependency between CudaPlatform
// and CudaExecutor.
extern const Platform::Id kCudaPlatformId;
} // namespace cuda
} // namespace gputools
} // namespace perftools
#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_ID_H_
......@@ -20,7 +20,8 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
#include "tensorflow/stream_executor/cuda/cuda_platform.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/dso_loader.h"
#include "tensorflow/stream_executor/lib/initialize.h"
......
......@@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/stream.h"
namespace perftools {
namespace gputools {
......@@ -61,6 +63,16 @@ bool CUDAStream::GetOrCreateCompletedEvent(CUevent *completed_event) {
return true;
}
CUDAStream *AsCUDAStream(Stream *stream) {
DCHECK(stream != nullptr);
return static_cast<CUDAStream *>(stream->implementation());
}
CUstream AsCUDAStreamValue(Stream *stream) {
DCHECK(stream != nullptr);
return AsCUDAStream(stream)->cuda_stream();
}
} // namespace cuda
} // namespace gputools
} // namespace perftools
......@@ -20,7 +20,7 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
namespace perftools {
......@@ -82,6 +82,13 @@ class CUDAStream : public internal::StreamInterface {
CUevent completed_event_ GUARDED_BY(mu_);
};
// Helper functions to simplify extremely common flows.
// Converts a Stream to the underlying CUDAStream implementation.
CUDAStream *AsCUDAStream(Stream *stream);
// Extracts a CUstream from a CUDAStream-backed Stream object.
CUstream AsCUDAStreamValue(Stream *stream);
} // namespace cuda
} // namespace gputools
} // namespace perftools
......
......@@ -111,6 +111,7 @@ class DeviceMemory final : public DeviceMemoryBase {
public:
// Default constructor instantiates a null-pointed, zero-sized memory region.
DeviceMemory() : DeviceMemoryBase(nullptr, 0) {}
DeviceMemory(std::nullptr_t) : DeviceMemory() {}
// Typed device memory regions may be constructed from untyped device memory
// regions, this effectively amounts to a cast from a void*.
......
......@@ -22,6 +22,20 @@ namespace perftools {
namespace gputools {
namespace dnn {
string QuantizedActivationModeString(QuantizedActivationMode mode) {
switch (mode) {
case dnn::QuantizedActivationMode::k8Bit:
return "uint8";
case dnn::QuantizedActivationMode::k16Bit:
return "uint16";
case dnn::QuantizedActivationMode::k32Bit:
return "int32";
default:
LOG(FATAL) << "Unknown quantized_activation_mode "
<< static_cast<int32>(mode);
}
}
string ActivationModeString(ActivationMode mode) {
switch (mode) {
case ActivationMode::kSigmoid:
......@@ -78,6 +92,17 @@ string FilterLayoutString(FilterLayout layout) {
}
}
string ShortPoolingModeString(PoolingMode mode) {
switch (mode) {
case PoolingMode::kMaximum:
return "Max";
case PoolingMode::kAverage:
return "Avg";
default:
LOG(FATAL) << "Unknown filter layout " << static_cast<int32>(mode);
}
}
// -- BatchDescriptor
BatchDescriptor::BatchDescriptor()
......@@ -137,7 +162,6 @@ string BatchDescriptor::ToShortString() const {
return port::StrCat(batch, depth, y, x, suffix);
default:
LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
return ""; // Avoid lack-of-return warning
}
}
......@@ -160,6 +184,20 @@ int64 BatchDescriptor::FullyConnectedBiasCount(const BatchDescriptor& output) {
return output.NodesAcrossFeatureMaps();
}
BatchDescriptor BatchDescriptor::DepthConcatenateOutputDescriptor(
port::ArraySlice<dnn::BatchDescriptor> inputs) {
if (inputs.empty()) {
return BatchDescriptor();
}
int feature_map_count = 0;
for (const auto& dimensions : inputs) {
feature_map_count += dimensions.feature_map_count();
}
BatchDescriptor output = inputs[0];
output.set_feature_map_count(feature_map_count);
return output;
}
// -- FilterDescriptor
FilterDescriptor::FilterDescriptor()
......@@ -205,7 +243,6 @@ string FilterDescriptor::ToShortString() const {
return port::StrCat(y, x, id, od);
default:
LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout_);
return ""; // Avoid lack-of-return warning
}
}
......
......@@ -32,6 +32,7 @@ namespace perftools {
namespace gputools {
class Stream;
class ScratchAllocator;
namespace dnn {
......@@ -55,6 +56,9 @@ enum class QuantizedActivationMode {
k32Bit = 4,
};
// Returns a string representation of the given quantization mode.
string QuantizedActivationModeString(QuantizedActivationMode mode);
// Describes the dimensions that a layer consumes/produces.
//
// This is a matrix (height, width), its "depth" (feature_map_count),
......@@ -175,6 +179,13 @@ class BatchDescriptor {
// with dimensions given the 'output' descriptor.
static int64 FullyConnectedBiasCount(const BatchDescriptor& output);
// Return a BatchDescriptor for the output of a depth concatenation
// with the given input descriptors. The inputs should have the same
// dimensions, except possibly for feature_map_count(), though this
// function does not verify that.
static BatchDescriptor DepthConcatenateOutputDescriptor(
port::ArraySlice<dnn::BatchDescriptor> inputs);
private:
int64 count_;
int64 feature_map_count_;
......@@ -280,8 +291,6 @@ class FilterDescriptor {
int64 input_filter_height_;
int64 input_filter_width_;
FilterLayout layout_;
SE_DISALLOW_COPY_AND_ASSIGN(FilterDescriptor);
};
// Describes a convolution.
......@@ -356,6 +365,9 @@ enum class PoolingMode : int64 {
kAverage,
};
// Returns a short name for the pooling mode, e.g. "Avg".
string ShortPoolingModeString(PoolingMode mode);
// Describes a pooling operation to be enqueued onto a stream via a platform's
// DnnSupport.
//
......@@ -423,18 +435,31 @@ class PoolingDescriptor {
int64 horizontal_padding_;
int64 vertical_stride_;
int64 horizontal_stride_;
SE_DISALLOW_COPY_AND_ASSIGN(PoolingDescriptor);
};
// Describes a dist_belief local response normalization.
// The normalization equation is:
// y_i = x_i / (bias + alpha * (sum_j_{i - range}^{i + range} x_j^2)) ^ beta
// where x_i is the input in feature map i, y_i is the output.
// Each feature map is split into segment_size segments for performing the
// sum_j_. If wrap_around is true, the sum_j_ for y_i on the left and right of
// a segment wrap around at the edges of the segment, if wrap_around is false
// zeros are inserted instead.
// Describes a local response normalization (LRN). LRN is used e.g. in
// dist_belief.
//
// Let V be the vector of feature maps at some (batch, y, x)
// coordinate. LRN applies independently to each vector V in the
// input, across all coordinates (batch, y, x), by mapping each V to
// another vector U of the same size using the formula
//
// V_i = U_i / ((bias + alpha * (sum_j U_j^2)) ^ beta)
//
// where the sum is taken for j in the inclusive range [i - range, i + range].
//
// When calculating V_i the j in the sum can extend beyond the bounds
// of U. If wrap_around is true, then U_j = U_{j mod F} where F is the
// size of U, which is the number of feature maps. If wrap_around is
// false, then U_j = 0 for j outside [0, F-1].
//
// If segment_size <= F, where F is the number of feature_maps, then
// segment_size has no effect. Otherwise, each consecutive segment of
// segment_size entries in V are normalized separately.
//
// Not all StreamExecutors allow wrap_around == true or segment_size
// != 64. Some do not implement normalization at all.
class NormalizeDescriptor {
public:
NormalizeDescriptor();
......@@ -488,8 +513,6 @@ class NormalizeDescriptor {
float beta_;
bool wrap_around_;
int32 segment_size_;
SE_DISALLOW_COPY_AND_ASSIGN(NormalizeDescriptor);
};
// Describes a kind of non-linearity (threshold-like mathematical function).
......@@ -503,6 +526,8 @@ enum class ActivationMode {
// BatchDescriptor::value_max().
kReluX,
kTanh,
// Like ReluX, but passes all values in the range [-X,X].
kBandPass,
};
// Returns a string representation of the given activation mode.
......@@ -510,10 +535,7 @@ string ActivationModeString(ActivationMode mode);
// Describes the operation that DoElementwiseOperation should perform on its
// inputs.
enum class ElementwiseOperation {
kAdd,
kMultiply
};
enum class ElementwiseOperation { kAdd, kMultiply };
string ElementwiseOperationString(ElementwiseOperation op);
......@@ -541,6 +563,8 @@ class DnnSupport {
// output_descriptor: dimensions of the output layer.
// output_data: un-owned device memory region in which to place the
// convolution result.
// scratch_allocator: un-owned, may-be-null object that may allocate scratch
// space in order to speed up the convolution operation.
//
// input_descriptor, filter_descriptor, convolution_descriptor and
// output_descriptor together specify exactly how the convolution is aligned
......@@ -564,7 +588,8 @@ class DnnSupport {
const DeviceMemory<float>& filter_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::BatchDescriptor& output_descriptor,
DeviceMemory<float>* output_data) = 0;
DeviceMemory<float>* output_data,
ScratchAllocator* scratch_allocator) = 0;
// Enqueues a double-precision convolution operation onto the stream.
// See DoConvolve above for argument details.
......@@ -612,6 +637,8 @@ class DnnSupport {
// input_descriptor: dimensions of the input layer.
// backward_input_data: un-owned device memory region in which to place the
// backprop of the input.
// scratch_allocator: un-owned, may-be-null object that may allocate scratch
// space in order to speed up the convolution operation.
virtual bool DoConvolveBackwardData(
Stream* stream, const FilterDescriptor& filter_descriptor,
const DeviceMemory<float>& filter_data,
......@@ -619,7 +646,8 @@ class DnnSupport {
DeviceMemory<float> backward_output_data,
const ConvolutionDescriptor& convolution_descriptor,
const BatchDescriptor& input_descriptor,
DeviceMemory<float>* backward_input_data) = 0;
DeviceMemory<float>* backward_input_data,
ScratchAllocator* scratch_allocator) = 0;
// Enqueues a single-precision backward convolution (for filter) operation
// onto
......@@ -640,6 +668,8 @@ class DnnSupport {
// filter_descriptor: dimensions of the convolution filter.
// backward_filter_data: un-owned device memory region in which to place the
// backprop of the filter.
// scratch_allocator: un-owned, may-be-null object that may allocate scratch
// space in order to speed up the convolution operation.
virtual bool DoConvolveBackwardFilter(
Stream* stream, const BatchDescriptor& input_descriptor,
const DeviceMemory<float>& input_data,
......@@ -647,7 +677,8 @@ class DnnSupport {
DeviceMemory<float> backward_output_data,
const ConvolutionDescriptor& convolution_descriptor,
const FilterDescriptor& filter_descriptor,
DeviceMemory<float>* backward_filter_data) = 0;
DeviceMemory<float>* backward_filter_data,
ScratchAllocator* scratch_allocator) = 0;
// Fully connects the "nodes" (float values) in input_data with
// shape input_dimensions to output_data with output_dimensions
......@@ -784,8 +815,10 @@ class DnnSupport {
const DeviceMemory<float>& input_diff_data,
DeviceMemory<float>* output_diff_data) = 0;
// Applies local response normalization to all of the values
// held on the device in 'input_data'.
// Applies local response normalization to the values from
// input_data and writes the result to output_data. See comments on
// NormalizeDescriptor for a description of local response
// normalization.
virtual bool DoNormalize(Stream* stream,
const dnn::NormalizeDescriptor& normalize_descriptor,
const DeviceMemory<float>& input_data,
......@@ -850,6 +883,46 @@ class DnnSupport {
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<float>* output_data) = 0;
// Pads the input with zeros in the X and Y dimensions. The feature_map
// dimension is unchanged.
//
// Arguments (all borrowed):
// stream: borrowed pointer to the stream that the 'elementwise operation'
// should be enqueued onto.
// dimensions: The dimensions of the input.
// input_data: un-owned device memory region which contains the
// input data for the input layer.
// left_pad: Amount to pad the input on the left.
// right_pad: Amount to pad the input on the right.
// top_pad: Amount to pad the input at the top (low Y).
// bottom_pad: Amount to pad the input at the bottom (high Y).
// output_data: un-owned device memory region in which to place the
// padded result.
virtual bool DoXYPad(Stream* stream, const dnn::BatchDescriptor &dimensions,
const DeviceMemory<float> &input_data,
int64 left_pad, int64 right_pad, int64 top_pad,
int64 bottom_pad, DeviceMemory<float> *output_data) = 0;
// Extracts a slice of the input in the X and Y dimensions. The feature_map
// dimension is unchanged.
//
// Arguments (all borrowed):
// stream: borrowed pointer to the stream that the 'elementwise operation'
// should be enqueued onto.
// dimensions: The dimensions of the input.
// input_data: un-owned device memory region which contains the
// input data for the input layer.
// left_trim: Amount to cut off the input on the left.
// right_trim: Amount to cut off the input on the right.
// top_trim: Amount to cut off the input at the top (low y).
// bottom_trim: Amount to cut off the input at the bottom (high Y).
// output_data: un-owned device memory region in which to place the
// padded result.
virtual bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor &dimensions,
const DeviceMemory<float> &input_data,
int64 left_trim, int64 right_trim, int64 top_trim,
int64 bottom_trim, DeviceMemory<float> *output_data) = 0;
// Enqueues an asynchronous memcpy of the *quantized* output of a layer (that
// is, bytes instead of scaled floats) into 'host_dst' if they are available
// for the underlying DNN implementation. If this quantized output is not
......@@ -862,23 +935,14 @@ class DnnSupport {
// gpu_unquantized_src: the device memory that contains the unquantized data
// -- this data should also have a corresponding quantized representation
// on the device for this operation to succeed.
// mode: Type of quantization of the data to write into host_dst.
// host_dst: un-owned host memory region that is mutated in place,
// it is clobbered by the values in 'gpu_unquantized_src' when the enqueued
// (asynchronous) memcpy operation is performed.
// TODO(wgulland) Merge all these versions of DoMemcpyD2HQuantized.
virtual bool DoMemcpyD2HQuantized(
Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
port::MutableArraySlice<uint8> host_dst) = 0;
// As above, but for 16-bit values.
virtual bool DoMemcpyD2HQuantized(
Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
port::MutableArraySlice<uint16> host_dst) = 0;
// As above, but for signed 32-bit values.
// size: size in bytes of the host_dst host memory region.
virtual bool DoMemcpyD2HQuantized(
Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
port::MutableArraySlice<int32> host_dst) = 0;
QuantizedActivationMode mode, void* host_dst, int64 size) = 0;
// Enqueues an asynchronous memcpy of 'host_dst' into the *quantized* input
// of a layer (that is, bytes instead of scaled floats) if they are supported
......@@ -890,13 +954,16 @@ class DnnSupport {
// stream: borrowed pointer to the stream that the 'quantized memcpy'
// operation should be enqueued onto.
// host_src: un-owned host memory region that contains the quantized data.
// size: size in bytes of the host_src host memory region.
// mode: Type of quantization of the data to read from host_src.
// gpu_unquantized_dst: the device memory that is clobbered by the values in
// 'host_src' when the enqueued (asynchronous) memcpy operation is
// performed. -- this data should also have a corresponding quantized
// representation on the device for this operation to
// succeed.
virtual bool DoMemcpyH2DQuantized(
Stream* stream, port::ArraySlice<uint8> host_src,
Stream* stream, const void* host_src, int64 size,
QuantizedActivationMode mode,
DeviceMemory<float>* gpu_unquantized_dst) = 0;
private:
......
......@@ -42,11 +42,12 @@ namespace internal {
}
/* static */ port::Status DsoLoader::GetCudnnDsoHandle(void** dso_handle) {
// libcudnn is versioned differently than the other libraries. See b/22397368
// for some details about the complications surrounding this.
return GetDsoHandle(FindDsoPath("libcudnn.so.6.5",
"third_party/gpus/cuda/lib64"),
dso_handle);
// libcudnn is versioned differently than the other libraries and may have a
// different version number than other CUDA libraries. See b/22397368 for
// some details about the complications surrounding this.
return GetDsoHandle(
FindDsoPath("libcudnn.so.6.5", "third_party/gpus/cuda/lib64"),
dso_handle);
}
/* static */ port::Status DsoLoader::GetCufftDsoHandle(void** dso_handle) {
......@@ -89,16 +90,16 @@ namespace internal {
string path_string = path.ToString();
*dso_handle = dlopen(path_string.c_str(), dynload_flags);
if (*dso_handle == nullptr) {
LOG(INFO) << "LD_LIBRARY_PATH: " << getenv("LD_LIBRARY_PATH");
LOG(INFO) << "Couldn't open CUDA library " << path
<< ". LD_LIBRARY_PATH: " << getenv("LD_LIBRARY_PATH");
// TODO(b/22689637): Eliminate unnecessary ToString once StrCat has been
// moved to the open-sourceable version.
return port::Status(
port::error::FAILED_PRECONDITION,
port::StrCat("could not dlopen DSO: ", path, "; dlerror: ", dlerror()));
}
VLOG(2) << "loaded path \"" << path << "\" "
<< (load_kind == LoadKind::kLocal ? "locally" : "globally");
LOG(INFO) << "successfully opened CUDA library " << path
<< (load_kind == LoadKind::kLocal ? " locally" : " globally");
return port::Status::OK();
}
......
......@@ -22,21 +22,10 @@ limitations under the License.
namespace perftools {
namespace gputools {
internal::EventInterface* CreateEventImplementation(
StreamExecutor* stream_exec) {
PlatformKind platform_kind = stream_exec->platform_kind();
switch (platform_kind) {
case PlatformKind::kCuda:
return (*internal::MakeCUDAEventImplementation())(stream_exec);
default:
LOG(FATAL) << "Cannot create event implementation for platform kind: "
<< PlatformKindString(platform_kind);
}
}
Event::Event(StreamExecutor* stream_exec)
: implementation_(CreateEventImplementation(stream_exec)),
stream_exec_(stream_exec) {}
: stream_exec_(stream_exec),
implementation_(
stream_exec_->implementation()->CreateEventImplementation()) {}
Event::~Event() {
auto status = stream_exec_->DeallocateEvent(this);
......
......@@ -18,6 +18,8 @@ limitations under the License.
#include <memory>
#include "tensorflow/stream_executor/platform/port.h"
namespace perftools {
namespace gputools {
......@@ -63,13 +65,15 @@ class Event {
private:
friend class Stream;
// Pointer to the StreamExecutor interface used to create this object.
// Not owned.
StreamExecutor* stream_exec_;
// Pointer to the platform-specific EventInterface implementation underlying
// the object. Owned.
std::unique_ptr<internal::EventInterface> implementation_;
// Pointer to the StreamExecutor interface used to create this object.
// Not owned.
StreamExecutor* stream_exec_;
SE_DISALLOW_COPY_AND_ASSIGN(Event);
};
} // namespace gputools
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -60,4 +60,5 @@ typename MakeUniqueResult<T>::invalid MakeUnique(Args&&... /* args */) =
} // namespace gputools
} // namespace perftools
#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_PTR_UTIL_H_
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册