diff --git a/tensorflow/core/common_runtime/data/BUILD b/tensorflow/core/common_runtime/data/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..2544cc67af695b823c0d8db7120020f0f51b0dc1 --- /dev/null +++ b/tensorflow/core/common_runtime/data/BUILD @@ -0,0 +1,29 @@ +package( + licenses = ["notice"], # Apache 2.0 +) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all") + +cc_library( + name = "standalone", + srcs = ["standalone.cc"], + hdrs = ["standalone.h"], + deps = [ + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:session_options", + ], +) + +tf_cc_test( + name = "standalone_test", + srcs = ["standalone_test.cc"], + deps = [ + ":standalone", + "//tensorflow/core:all_kernels", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ] + tf_protos_all(), +) diff --git a/tensorflow/core/common_runtime/data/standalone.cc b/tensorflow/core/common_runtime/data/standalone.cc new file mode 100644 index 0000000000000000000000000000000000000000..eebf00096a075042a29e4a006266f0e5be39095a --- /dev/null +++ b/tensorflow/core/common_runtime/data/standalone.cc @@ -0,0 +1,139 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/data/standalone.h" + +#include + +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_runner.h" +#include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace data { +namespace standalone { + +Status Iterator::GetNext(std::vector* outputs, bool* end_of_input) { + return iterator_->GetNext(ctx_.get(), outputs, end_of_input); +} + +Iterator::Iterator(IteratorBase* iterator, IteratorContext* ctx) + : iterator_(iterator), ctx_(ctx) {} + +Status Dataset::FromGraph(Params params, const GraphDef& graph_def, + std::unique_ptr* result) { + Graph graph(OpRegistry::Global()); + TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr)); + + // Instantiate enough of the TensorFlow runtime to run `graph` on a single CPU + // device. + std::unique_ptr device_mgr = + MakeUnique(DeviceFactory::NewDevice( + "CPU", params.session_options, "/job:localhost/replica:0/task:0")); + Device* device = device_mgr->ListDevices()[0]; + // Clone the `FunctionLibraryDefinition` to extend its lifetime extends beyond + // the lifetime of `graph`. + std::unique_ptr flib_def = + MakeUnique(graph.flib_def()); + std::unique_ptr pflr = + MakeUnique( + device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, + flib_def.get(), OptimizerOptions{}, nullptr /* parent */); + + string fetch_node = ""; + for (auto node : graph_def.node()) { + if (node.op() == "_Retval") { + fetch_node = node.input(0); + } + } + if (fetch_node.empty()) { + return errors::NotFound("Failed to find a _Retval op in the given dataset"); + } + + // Run graph up to `output_node` and extract the `DatasetBase` stored in the + // DT_VARIANT output tensor. + data::DatasetBase* dataset; + { + std::vector outputs; + GraphRunner graph_runner(device); + TF_RETURN_IF_ERROR(graph_runner.Run(&graph, pflr->GetFLR("/device:CPU:0"), + {}, {fetch_node}, &outputs)); + TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset)); + // NOTE(mrry): The dataset is currently owned by `outputs[0]`, so acquire an + // additional reference. + dataset->Ref(); + } + + std::unique_ptr pool( + NewThreadPoolFromSessionOptions(params.session_options)); + *result = + WrapUnique(new Dataset(dataset, device_mgr.release(), pflr.release(), + flib_def.release(), pool.release())); + return Status::OK(); +} // static + +Status Dataset::MakeIterator(std::unique_ptr* result) { + // Create an `IteratorContext`, which bundles together the necessary runtime + // support to create and get elements from an iterator. + std::unique_ptr ctx; + { + // NOTE(mrry): In the current API, an `IteratorContext` is always initially + // created from an `OpKernelContext*`, so we need to create a fake + // `OpKernelContext` with the appropriate subset of parameters. + OpKernelContext::Params op_params; + op_params.function_library = pflr_->GetFLR("/device:CPU:0"); + op_params.device = device_mgr_->ListDevices()[0]; + op_params.runner = &runner_; + OpKernelContext op_ctx(&op_params, 0); + IteratorContext::Params params(&op_ctx); + params.function_handle_cache = function_handle_cache_.get(); + ctx = MakeUnique(std::move(params)); + } + + // Create the iterator from the dataset. + std::unique_ptr iterator; + TF_RETURN_IF_ERROR(dataset_->MakeIterator(ctx.get(), "iterator", &iterator)); + + *result = WrapUnique(new Iterator(iterator.release(), ctx.release())); + + return Status::OK(); +} + +Dataset::Dataset(DatasetBase* dataset, DeviceMgr* device_mgr, + ProcessFunctionLibraryRuntime* pflr, + FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool) + : dataset_(dataset), + device_mgr_(device_mgr), + flib_def_(flib_def), + pflr_(pflr), + pool_(pool) { + runner_ = [this](std::function c) { pool_->Schedule(std::move(c)); }; + function_handle_cache_ = + MakeUnique(pflr_->GetFLR("/device:CPU:0")); +} + +Dataset::~Dataset() { dataset_->Unref(); } + +} // namespace standalone +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/data/standalone.h b/tensorflow/core/common_runtime/data/standalone.h new file mode 100644 index 0000000000000000000000000000000000000000..7ec420ab8acb1a71cf6001039fe53c378421e519 --- /dev/null +++ b/tensorflow/core/common_runtime/data/standalone.h @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DATA_STANDALONE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_DATA_STANDALONE_H_ + +#include + +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/function_handle_cache.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace data { +namespace standalone { + +// The purpose of the API in this file is to facilitate standalone execution of +// a tf.data input pipeline graph. +// +// The API exposes two abstractions -- a `Dataset` and an `Iterator` -- which +// encapsulate TensorFlow runtime. +// +// The `Dataset` abstraction represents an input pipeline as a collection +// of data sources and a logical plan of transformations that operate over the +// data. +// +// The `Iterator` abstraction represents an execution of an input pipeline that +// can be used to enumerate its elements. +// +// Example usage: +// +// // Create a `Dataset` by running the `graph_def` graph. +// tensorflow::data:standalone::Dataset::Params params; +// std::unique_ptr dataset; +// Status s = tensorflow::data::standalone::Dataset::FromGraph( +// params, graph_def, &dataset); +// if (!s.ok()) { /* error handling */ } +// +// std::unique_ptr iterator; +// s = dataset->MakeIterator(&iterator); +// if (!s.ok()) { /* error handling */ } +// +// bool end_of_input = false; +// while (!end_of_input) { +// std::vector outputs; +// s = iterator->GetNext(&outputs, &end_of_input); +// if (!s.ok()) { /* error handling */ } +// if (!end_of_input) { /* output handling */ } +// } + +class Dataset; + +// Represents an execution of an input pipeline that can be used to enumerate +// its elements. +class Iterator { + public: + // Returns the next element of the input pipeline (if there is one) and an + // indication of whether the end of the input pipeline has been reached. + Status GetNext(std::vector* outputs, bool* end_of_input); + + private: + friend class Dataset; + + Iterator(IteratorBase* iterator, IteratorContext* ctx); + + std::unique_ptr iterator_; + std::unique_ptr ctx_; +}; + +// Represents an input pipeline as a collection of data sources and a logical +// plan of transformations that operate over the data. +class Dataset { + public: + // Parameters for `Dataset` creation (e.g. TensorFlow runtime configuration). + struct Params { + SessionOptions session_options; + }; + + // Creates a new `Dataset` instance by running the given dataset graph. + static Status FromGraph(Params params, const GraphDef& graph_def, + std::unique_ptr* result); + + ~Dataset(); + + // Creates an iterator for this dataset. + Status MakeIterator(std::unique_ptr* result); + + private: + Dataset(DatasetBase* dataset, DeviceMgr* device_mgr, + ProcessFunctionLibraryRuntime* pflr, + FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool); + + DatasetBase* dataset_; // owned + std::unique_ptr device_mgr_; + std::unique_ptr flib_def_; + std::unique_ptr pflr_; + std::unique_ptr pool_; + std::unique_ptr function_handle_cache_; + std::function)> runner_; +}; + +} // namespace standalone +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DATA_STANDALONE_H_ diff --git a/tensorflow/core/common_runtime/data/standalone_test.cc b/tensorflow/core/common_runtime/data/standalone_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e7216b2d457ecf66a3b74835794ada0befa58648 --- /dev/null +++ b/tensorflow/core/common_runtime/data/standalone_test.cc @@ -0,0 +1,307 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/data/standalone.h" + +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace data { +namespace standalone { +namespace { + +constexpr const char* const kRangeGraphProto = R"proto( + node { + name: "Const/_0" + op: "Const" + attr { + key: "dtype" + value { type: DT_INT64 } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape {} + int64_val: 0 + } + } + } + } + node { + name: "Const/_1" + op: "Const" + attr { + key: "dtype" + value { type: DT_INT64 } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape {} + int64_val: 10 + } + } + } + } + node { + name: "Const/_2" + op: "Const" + attr { + key: "dtype" + value { type: DT_INT64 } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape {} + int64_val: 1 + } + } + } + } + node { + name: "RangeDataset/_3" + op: "RangeDataset" + input: "Const/_0" + input: "Const/_1" + input: "Const/_2" + attr { + key: "output_shapes" + value { list { shape {} } } + } + attr { + key: "output_types" + value { list { type: DT_INT64 } } + } + } + node { + name: "dataset" + op: "_Retval" + input: "RangeDataset/_3" + attr { + key: "T" + value { type: DT_VARIANT } + } + attr { + key: "index" + value { i: 0 } + } + } + library {} + versions { producer: 96 } +)proto"; + +// range(10).map(lambda x: x*x) +constexpr const char* const kMapGraphProto = R"proto( + node { + name: "Const/_0" + op: "Const" + attr { + key: "dtype" + value { type: DT_INT64 } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape {} + int64_val: 0 + } + } + } + } + node { + name: "Const/_1" + op: "Const" + attr { + key: "dtype" + value { type: DT_INT64 } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape {} + int64_val: 10 + } + } + } + } + node { + name: "Const/_2" + op: "Const" + attr { + key: "dtype" + value { type: DT_INT64 } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape {} + int64_val: 1 + } + } + } + } + node { + name: "RangeDataset/_3" + op: "RangeDataset" + input: "Const/_0" + input: "Const/_1" + input: "Const/_2" + attr { + key: "output_shapes" + value { list { shape {} } } + } + attr { + key: "output_types" + value { list { type: DT_INT64 } } + } + } + node { + name: "MapDataset/_4" + op: "MapDataset" + input: "RangeDataset/_3" + attr { + key: "Targuments" + value { list {} } + } + attr { + key: "f" + value { func { name: "__inference_Dataset_map__67" } } + } + attr { + key: "output_shapes" + value { list { shape {} } } + } + attr { + key: "output_types" + value { list { type: DT_INT64 } } + } + attr { + key: "preserve_cardinality" + value { b: false } + } + attr { + key: "use_inter_op_parallelism" + value { b: true } + } + } + node { + name: "dataset" + op: "_Retval" + input: "MapDataset/_4" + attr { + key: "T" + value { type: DT_VARIANT } + } + attr { + key: "index" + value { i: 0 } + } + } + library { + function { + signature { + name: "__inference_Dataset_map__67" + input_arg { name: "args_0" type: DT_INT64 } + output_arg { name: "identity" type: DT_INT64 } + } + node_def { + name: "mul" + op: "Mul" + input: "args_0" + input: "args_0" + attr { + key: "T" + value { type: DT_INT64 } + } + } + node_def { + name: "Identity" + op: "Identity" + input: "mul:z:0" + attr { + key: "T" + value { type: DT_INT64 } + } + } + ret { key: "identity" value: "Identity:output:0" } + arg_attr { + key: 0 + value { + attr { + key: "_user_specified_name" + value { s: "args_0" } + } + } + } + } + } + versions { producer: 96 min_consumer: 12 } +)proto"; + +TEST(Scalar, Standalone) { + struct TestCase { + string graph_string; + std::vector expected_outputs; + }; + auto test_cases = { + TestCase{kRangeGraphProto, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}}, + TestCase{kMapGraphProto, {0, 1, 4, 9, 16, 25, 36, 49, 64, 81}}, + }; + for (auto test_case : test_cases) { + GraphDef graph_def; + protobuf::TextFormat::ParseFromString(test_case.graph_string, &graph_def); + std::unique_ptr dataset; + auto s = Dataset::FromGraph({}, graph_def, &dataset); + TF_EXPECT_OK(s); + std::unique_ptr iterator; + s = dataset->MakeIterator(&iterator); + TF_EXPECT_OK(s); + bool end_of_input = false; + for (int num_outputs = 0; !end_of_input; ++num_outputs) { + std::vector outputs; + s = iterator->GetNext(&outputs, &end_of_input); + TF_EXPECT_OK(s); + if (!end_of_input) { + EXPECT_EQ(outputs[0].scalar()(), + test_case.expected_outputs[num_outputs]); + } else { + EXPECT_EQ(test_case.expected_outputs.size(), num_outputs); + } + } + } +} + +} // namespace +} // namespace standalone +} // namespace data +} // namespace tensorflow