提交 06fb6333 编写于 作者: J Jiri Simsa 提交者: TensorFlower Gardener

Automated rollback of commit 432de130

PiperOrigin-RevId: 246502904
上级 a85bbeb7
......@@ -683,6 +683,7 @@ class DatasetBase : public core::RefCounted {
protected:
friend Status AsGraphDef(
OpKernelContext* ctx, const DatasetBase* dataset,
SerializationContext&& serialization_ctx,
GraphDef* graph_def); // For access to graph related members.
friend class CapturedFunction;
......
......@@ -50,7 +50,8 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
"ZipDataset"
};
constexpr std::array<const char*, 22> kPassThroughOps = {
constexpr std::array<const char*, 23> kPassThroughOps = {
"_Retval",
"BatchDataset",
"BatchDatasetV2",
"ExperimentalMapAndBatchDataset",
......@@ -285,16 +286,14 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
// function in flat_map.
if (IsDatasetNodeOfType(node, kFuncDatasetOps) &&
ReaderOpInFunction(node, *flib)) {
TF_RETURN_IF_ERROR(ProcessDatasetSourceNode(graph, node, nodes_to_delete,
num_workers, index));
return Status::OK();
return ProcessDatasetSourceNode(graph, node, nodes_to_delete, num_workers,
index);
}
if (IsDatasetNodeOfType(node, kReaderDatasetOps)) {
// We reached a reader dataset directly and we try to shard input 0.
TF_RETURN_IF_ERROR(ProcessDatasetSourceNode(graph, node, nodes_to_delete,
num_workers, index));
return Status::OK();
return ProcessDatasetSourceNode(graph, node, nodes_to_delete, num_workers,
index);
}
if (!IsDatasetNodeOfType(node, kPassThroughOps)) {
......
......@@ -301,12 +301,11 @@ Status EnsureNodeNamesUnique(Graph* g) {
return Status::OK();
}
// Tries to find a Sink node in the graph. A sink node is defined as a node
// Tries to find a "sink" node in the graph. A sink node is defined as a node
// that has at least one input and no outputs. If there are multiple of these,
// this might return any one of them. This is useful to identify the final
// Dataset op in the graph but in some cases there might be multiple Identity
// ops added to the end and this would return the last Identity op in that case.
Status FindSinkNode(const GraphDef& graph_def, NodeDef* sink_node) {
absl::flat_hash_map<string, int> all_node_names;
absl::flat_hash_map<string, int> node_input_map;
......
......@@ -83,14 +83,6 @@ Status LatencyAllEdges::OptimizeAndCollectStats(Cluster* cluster,
// node corresponds to a `Dataset` op.
continue;
}
MutableGraphView::OutputPort output_port =
graph.GetOutputPort(node.name(), 0);
auto fanout = graph.GetFanout(output_port);
if (fanout.size() > 1) {
LOG(WARNING) << node.name() << " has fanout size " << fanout.size();
continue;
}
// fanout will have size 0 for last dataset node in the pipeline.
NodeDef* latency_node = graph.AddNode(MakeLatencyNode(node, &graph));
TF_RETURN_IF_ERROR(graph.UpdateFanouts(node.name(), latency_node->name()));
stats->num_changes++;
......
......@@ -41,8 +41,9 @@ Status RebatchOptimizer::Init(
namespace {
constexpr char kCastOp[] = "Cast";
constexpr char kRealDivOp[] = "RealDiv";
constexpr char kConstOp[] = "Const";
constexpr char kIdentityOp[] = "Identity";
constexpr char kRealDivOp[] = "RealDiv";
constexpr std::array<const char*, 5> kBatchDatasetOps = {
"BatchDataset",
......@@ -135,12 +136,24 @@ bool IsDatasetNodeOfType(const NodeDef& node,
return false;
}
// Divides the leading dimension of every shape in the "output_shapes"
// attribute of the node named `node_name` by `num_workers`.
//
// Args:
//   node_name: name of the node to update; looked up in `graph`.
//   num_workers: divisor applied to the leading dimension (integer division).
//     Assumed to be positive -- TODO confirm that callers validate this.
//   graph: graph view that owns the node; the node is modified in place.
//
// Returns NotFound if `graph` has no node called `node_name`, OK otherwise.
//
// The following are left untouched instead of crashing or being corrupted:
//   * Identity nodes (skipped up front, as before),
//   * nodes without an "output_shapes" attribute,
//   * shapes of unknown rank or with no dimensions,
//   * leading dimensions whose size is negative (unknown); dividing -1 would
//     silently turn an unknown size into 0.
Status UpdateOutputShapes(const string& node_name, int64 num_workers,
                          MutableGraphView* graph) {
  NodeDef* node = graph->GetNode(node_name);
  if (node == nullptr) {
    return errors::NotFound("Node ", node_name, " not found in the graph.");
  }
  if (node->op() == kIdentityOp) {
    return Status::OK();
  }
  auto* attrs = node->mutable_attr();
  auto shapes_it = attrs->find("output_shapes");
  if (shapes_it == attrs->end()) {
    // Nothing to rewrite for nodes that do not carry shape information.
    return Status::OK();
  }
  for (auto& shape : *shapes_it->second.mutable_list()->mutable_shape()) {
    if (shape.unknown_rank() || shape.dim_size() == 0) {
      continue;
    }
    const int64 leading_dim = shape.dim(0).size();
    if (leading_dim < 0) {
      // Unknown size: keep it unknown.
      continue;
    }
    shape.mutable_dim(0)->set_size(leading_dim / num_workers);
  }
  return Status::OK();
}
// Given a "batch" dataset node, modifies the batch_size input to divide the
// current batch size by num_workers.
Status MutateBatchSize(const NodeDef& node, int64 num_workers,
MutableGraphView* graph) {
// TODO(rohanj): Fix up the output_shapes attribute as well. For this Dataset
// as well as all the downstream datasets.
// For all the batching datasets the batch_size is input number 1 except for
// MapAndBatchDataset.
int64 batch_size_arg_index = 1;
......@@ -194,7 +207,8 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers,
FunctionLibraryDefinition* flib,
MutableGraphView* graph) {
if (IsDatasetNodeOfType(node, kBatchDatasetOps)) {
return MutateBatchSize(node, num_workers, graph);
TF_RETURN_IF_ERROR(MutateBatchSize(node, num_workers, graph));
TF_RETURN_IF_ERROR(UpdateOutputShapes(node.name(), num_workers, graph));
} else if (IsDatasetNodeOfType(node, kMultipleInputsDatasetOps)) {
// For all multiple input datasets, all inputs are datasets themselves.
for (int i = 0; i < node.input_size(); ++i) {
......@@ -202,12 +216,14 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers,
TF_RETURN_IF_ERROR(
RecursivelyHandleOp(*input_node, num_workers, flib, graph));
}
TF_RETURN_IF_ERROR(UpdateOutputShapes(node.name(), num_workers, graph));
} else if (IsDatasetNodeOfType(node, kPassThroughOps)) {
// For all the dataset ops that are pass through, the input dataset is
// input 0.
NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
TF_RETURN_IF_ERROR(
RecursivelyHandleOp(*input_node, num_workers, flib, graph));
TF_RETURN_IF_ERROR(UpdateOutputShapes(node.name(), num_workers, graph));
} else if (IsDatasetNodeOfType(node, kFuncDatasetOps)) {
const string func_name = node.attr().at("f").func().name();
const FunctionDef* fdef = flib->Find(func_name);
......@@ -233,6 +249,7 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers,
// Replace optimized function with a new FunctionDef.
TF_RETURN_IF_ERROR(flib->ReplaceFunction(func_name, optimized_func));
TF_RETURN_IF_ERROR(UpdateOutputShapes(node.name(), num_workers, graph));
} else {
VLOG(2) << "Failed to optimize dataset function. Error: "
<< s.error_message();
......
......@@ -51,6 +51,14 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:grappler_item_builder",
"//tensorflow/core/grappler/clusters:virtual_cluster",
"//tensorflow/core/grappler/optimizers:meta_optimizer",
"//tensorflow/core/grappler/optimizers/data",
"//tensorflow/core/grappler/optimizers/data:function_utils",
"//tensorflow/core/grappler/optimizers/data:graph_utils",
],
)
......@@ -936,31 +944,11 @@ tf_kernel_library(
],
)
cc_library(
name = "graph_rewrite_dataset",
srcs = ["graph_rewrite_dataset.cc"],
hdrs = ["graph_rewrite_dataset.h"],
deps = [
":captured_function",
":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:grappler_item_builder",
"//tensorflow/core/grappler/clusters:virtual_cluster",
"//tensorflow/core/grappler/optimizers:meta_optimizer",
"//tensorflow/core/grappler/optimizers/data",
"//tensorflow/core/grappler/optimizers/data:function_utils",
"//tensorflow/core/grappler/optimizers/data:graph_utils",
],
)
tf_kernel_library(
name = "optimize_dataset_op",
srcs = ["optimize_dataset_op.cc"],
deps = [
":graph_rewrite_dataset",
":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
......
......@@ -32,7 +32,8 @@ class DatasetToGraphOp : public OpKernel {
DatasetBase* dataset;
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
GraphDef graph_def;
OP_REQUIRES_OK(ctx, AsGraphDef(ctx, dataset, &graph_def));
OP_REQUIRES_OK(
ctx, AsGraphDef(ctx, dataset, SerializationContext({}), &graph_def));
Tensor* result;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
result->scalar<string>()() = graph_def.SerializeAsString();
......
......@@ -17,20 +17,128 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
namespace data {
namespace {
// For every output argument of `function_def`, appends an Identity node named
// with the "FakeSink<N>" prefix that consumes the tensor currently returned
// for that argument, then points the return value at the new node's output.
void AddFakeSinks(FunctionDef* function_def) {
  int sink_index = 0;
  for (const auto& out_arg : function_def->signature().output_arg()) {
    const string& ret_key = out_arg.name();

    NodeDef* sink = function_def->add_node_def();
    const string name_prefix = strings::StrCat("FakeSink", sink_index);
    ++sink_index;
    tensorflow::grappler::function_utils::SetUniqueFunctionNodeName(
        name_prefix, function_def, sink);

    sink->set_op("Identity");
    // The sink forwards whatever tensor currently backs this return value.
    sink->add_input(function_def->ret().at(ret_key));
    auto& type_attr = (*sink->mutable_attr())["T"];
    type_attr.set_type(out_arg.type());

    // Route the return value through the freshly added sink.
    (*function_def->mutable_ret())[ret_key] =
        strings::StrCat(sink->name(), ":output:0");
  }
}
// Rewires each return value of `function_def` that is produced by a
// single-input Identity node so that it refers to that node's input tensor
// instead. The Identity nodes themselves stay in the function body.
void RemoveFakeSinks(FunctionDef* function_def) {
  // Name of each single-input Identity node -> the tensor it forwards.
  std::map<string, string> forwarded_input;
  for (const auto& candidate : function_def->node_def()) {
    if (candidate.op() != "Identity" || candidate.input_size() != 1) {
      continue;
    }
    forwarded_input[candidate.name()] = candidate.input(0);
  }

  for (const auto& out_arg : function_def->signature().output_arg()) {
    const string& ret_tensor = function_def->ret().at(out_arg.name());
    // The producing node's name is everything before the first ':'.
    const string producer = ret_tensor.substr(0, ret_tensor.find(':'));
    const auto match = forwarded_input.find(producer);
    if (match == forwarded_input.end()) {
      continue;
    }
    (*function_def->mutable_ret())[out_arg.name()] = match->second;
  }
}
// Runs Grappler's meta optimizer, configured by `config_factory`, over
// `*graph_def` and writes the optimized graph back into `*graph_def`.
//
// Args:
//   ctx: kernel context; its device is handed to the meta optimizer.
//   config_factory: produces the RewriterConfig to optimize with.
//   optimize_function_library: forwarded to the Grappler item's optimization
//     options.
//   graph_def: in/out graph. It is modified before optimization starts (a
//     "Sink" Identity node and per-function fake sinks are appended), so it is
//     changed even when an error is returned.
//   output_node: in/out. On entry, the name of the node to fetch; on exit, the
//     name of the Identity node added in front of it.
//
// Returns Internal if no Grappler item could be built from the graph, or the
// status reported by the meta optimizer.
Status ApplyRewrites(OpKernelContext* ctx,
                     const std::function<RewriterConfig(void)> config_factory,
                     bool optimize_function_library, GraphDef* graph_def,
                     string* output_node) {
  // Add an identity node as the fetch node, otherwise we might get 'placeholder
  // is both fed and fetched' errors in some cases when using input list with
  // placeholder dataset nodes.
  NodeDef* node = graph_def->mutable_node()->Add();
  tensorflow::grappler::graph_utils::SetUniqueGraphNodeName("Sink", graph_def,
                                                            node);
  node->set_op("Identity");
  node->add_input(*output_node);
  (*node->mutable_attr())["T"].set_type(DT_VARIANT);
  *output_node = node->name();

  // Add fake sink node to graph and functions to allow rewriting the actual
  // sink nodes.
  //
  // TODO(b/118820916): When MetaOptimizer adds provisions for function retvals
  // to be optimizable, we will no longer need this.
  for (auto& function_def : *graph_def->mutable_library()->mutable_function()) {
    AddFakeSinks(&function_def);
  }

  // Create metagraph. The copy must be taken after the sinks were added.
  MetaGraphDef meta_graph_def;
  (*meta_graph_def.mutable_graph_def()) = *graph_def;

  // Grappler determines fetch ops from collection 'train_op'.
  CollectionDef collection_def;
  auto node_list = collection_def.mutable_node_list();
  node_list->add_value(*output_node);
  (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;

  // Create Grappler item.
  tensorflow::grappler::ItemConfig item_config;
  item_config.apply_optimizations = true;
  std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
      tensorflow::grappler::GrapplerItemFromMetaGraphDef(
          "graph", meta_graph_def, item_config);
  if (grappler_item == nullptr) {
    // Item construction failed; report it instead of dereferencing null.
    return errors::Internal(
        "Failed to create a Grappler item from the dataset graph.");
  }
  grappler_item->optimization_options().optimize_function_library =
      optimize_function_library;
  std::unordered_map<string, tensorflow::DeviceProperties> device_map;
  tensorflow::grappler::VirtualCluster cluster(device_map);

  // Run data optimizer using grappler's meta optimizer.
  tensorflow::ConfigProto config;
  *config.mutable_graph_options()->mutable_rewrite_options() = config_factory();
  TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
      *grappler_item, config, ctx->device(), &cluster, graph_def));

  // Remove fake sinks after optimizations are done.
  //
  // TODO(b/118820916): When MetaOptimizer adds provisions for function retvals
  // to be optimizable, we will no longer need this.
  for (auto& function_def : *graph_def->mutable_library()->mutable_function()) {
    RemoveFakeSinks(&function_def);
  }

  return Status::OK();
}
} // anonymous namespace
Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
SerializationContext&& serialization_ctx,
GraphDef* graph_def) {
GraphDefBuilder b;
DatasetBase::DatasetGraphDefBuilder db(&b);
Node* output_node = nullptr;
SerializationContext serialization_ctx({});
TF_RETURN_IF_ERROR(
db.AddInputDataset(&serialization_ctx, dataset, &output_node));
// Insert a purely symbolic _Retval node to indicate to consumers which Tensor
......@@ -44,6 +152,57 @@ Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
return Status::OK();
}
// Serializes `input` to a graph, rewrites that graph with the Grappler
// configuration produced by `config_factory`, runs the rewritten graph and
// stores the resulting dataset in `*rewritten_input`.
//
// On success a reference is taken on `*rewritten_input` (Ref() is called), so
// the caller is responsible for releasing it.
//
// NOTE(review): the cloned function library, process runtime and runtime are
// locals that go away when this function returns -- confirm the rewritten
// dataset does not keep pointers into them.
Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
                      std::function<RewriterConfig(void)> config_factory,
                      bool optimize_function_library,
                      DatasetBase** rewritten_input) {
  // Serialize in "optimization only" mode, collecting feeds in `feed_list`.
  SerializationContext::Params ser_params;
  std::vector<std::pair<string, Tensor>> feed_list;
  ser_params.input_list = &feed_list;
  ser_params.optimization_only = true;
  GraphDef rewritten_graph;
  TF_RETURN_IF_ERROR(AsGraphDef(ctx, input, SerializationContext(ser_params),
                                &rewritten_graph));

  // The node feeding `_Retval` is the one to fetch; if several `_Retval`
  // nodes exist the last one wins.
  string fetch_node;
  for (const auto& candidate : rewritten_graph.node()) {
    if (candidate.op() != "_Retval") {
      continue;
    }
    fetch_node = candidate.input(0);
  }

  VLOG(3) << "Before graph rewrites: " << rewritten_graph.DebugString();
  TF_RETURN_IF_ERROR(ApplyRewrites(ctx, config_factory,
                                   optimize_function_library, &rewritten_graph,
                                   &fetch_node));
  VLOG(3) << "After graph rewrites: " << rewritten_graph.DebugString();

  // Instantiate the optimized input pipeline by running the optimized graph
  // using the optimized function library.
  FunctionLibraryRuntime* cloned_flr = nullptr;
  std::unique_ptr<ProcessFunctionLibraryRuntime> cloned_pflr = nullptr;
  std::unique_ptr<FunctionLibraryDefinition> cloned_lib_def = nullptr;
  TF_RETURN_IF_ERROR(ctx->function_library()->Clone(
      &cloned_lib_def, &cloned_pflr, &cloned_flr, true));

  // Functions can be rewritten in place without being renamed (e.g. the
  // nested dataset graphs of FlatMap or Interleave), so merge the rewritten
  // library into the clone.
  TF_RETURN_IF_ERROR(
      AddToFunctionLibrary(cloned_lib_def.get(), rewritten_graph.library()));

  Graph imported_graph(OpRegistry::Global());
  TF_RETURN_IF_ERROR(
      ImportGraphDef({}, rewritten_graph, &imported_graph, nullptr));

  std::vector<Tensor> fetched;
  GraphRunner runner(cloned_flr->device());
  TF_RETURN_IF_ERROR(runner.Run(&imported_graph, cloned_flr, feed_list,
                                {fetch_node}, &fetched));
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(fetched[0], rewritten_input));
  (*rewritten_input)->Ref();
  return Status::OK();
}
Status VerifyTypesMatch(const DataTypeVector& expected,
const DataTypeVector& received) {
if (expected.size() != received.size()) {
......
......@@ -23,8 +23,15 @@ namespace data {
// Returns a GraphDef representation of the given dataset.
//
// The serialized graph is written to `*graph_def`. `serialization_ctx` is
// taken by rvalue reference, so callers pass either a temporary (for example
// `SerializationContext({})`) or `std::move` an existing context.
Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
SerializationContext&& serialization_ctx,
GraphDef* graph_def);
// Rewrites the input dataset using the given config.
//
// `config_factory` supplies the RewriterConfig used for the rewrite and
// `optimize_function_library` controls whether the function library is
// optimized as well. On success `*rewritten_input` points at the rewritten
// dataset.
// NOTE(review): the implementation appears to hand back `*rewritten_input`
// with a reference already taken -- confirm the caller is expected to release
// it.
Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
std::function<RewriterConfig(void)> config_factory,
bool optimize_function_library,
DatasetBase** rewritten_input);
// Returns Status::OK() if `expected` and `received` types match,
// errors::InvalidArgument otherwise.
Status VerifyTypesMatch(const DataTypeVector& expected,
......
......@@ -10,10 +10,6 @@ load(
"//tensorflow:tensorflow.bzl",
"tf_kernel_library",
)
load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_proto_library",
)
tf_kernel_library(
name = "assert_next_dataset_op",
......@@ -25,6 +21,21 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "auto_shard_dataset_op",
srcs = ["auto_shard_dataset_op.cc"],
deps = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler/optimizers/data:auto_shard",
"//tensorflow/core/kernels/data:dataset_utils",
],
)
tf_kernel_library(
name = "choose_fastest_branch_dataset_op",
srcs = ["choose_fastest_branch_dataset_op.cc"],
......@@ -73,21 +84,6 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "auto_shard_dataset_op",
srcs = ["auto_shard_dataset_op.cc"],
deps = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler/optimizers/data:auto_shard",
"//tensorflow/core/kernels/data:graph_rewrite_dataset",
],
)
tf_kernel_library(
name = "group_by_reducer_dataset_op",
srcs = ["group_by_reducer_dataset_op.cc"],
......@@ -251,7 +247,7 @@ tf_kernel_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler/optimizers/data:rebatch",
"//tensorflow/core/kernels/data:graph_rewrite_dataset",
"//tensorflow/core/kernels/data:dataset_utils",
],
)
......
......@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
namespace data {
......@@ -24,17 +25,12 @@ constexpr char kOptimizerName[] = "tf_auto_shard";
class AutoShardDatasetOp : public UnaryDatasetOpKernel {
public:
explicit AutoShardDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
graph_def_version_(ctx->graph_def_version()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
: UnaryDatasetOpKernel(ctx) {}
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 index;
int64 num_workers;
int64 index, num_workers;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_workers", &num_workers));
OP_REQUIRES(
ctx, num_workers > 0,
......@@ -45,69 +41,39 @@ class AutoShardDatasetOp : public UnaryDatasetOpKernel {
errors::InvalidArgument("index must be between 0 and ",
num_workers - 1));
Dataset* dataset = new Dataset(ctx, input, num_workers, index,
output_types_, output_shapes_);
const Status s = dataset->Optimize(ctx);
if (s.ok()) {
*output = dataset;
} else {
dataset->Unref();
OP_REQUIRES_OK(ctx, s);
}
auto config_factory = [num_workers, index]() {
return CreateConfig(num_workers, index);
};
// We only want to optimize functions for some particular datasets like
// FlatMapDataset, InterleaveDataset etc. So we disable generalized
// function optimization and explicitly handle function modifications
// for those datasets in the rewrite.
OP_REQUIRES_OK(ctx,
RewriteDataset(ctx, input, std::move(config_factory),
/*optimize_function_library=*/false, output));
}
private:
// Graph-rewriting dataset that applies the auto-shard optimizer, parameterized
// by the number of workers and this worker's index.
class Dataset : public GraphRewriteDataset {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input,
          const int64 num_workers, const int64 index,
          const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes)
      : GraphRewriteDataset(ctx, input, output_types, output_shapes),
        num_workers_(num_workers),
        index_(index) {}

  string DebugString() const override { return "AutoShardDatasetOp::Dataset"; }

 private:
  // Generalized function optimization stays off: only specific datasets
  // (FlatMapDataset, InterleaveDataset, ...) should have their functions
  // modified, and the rewrite handles those explicitly.
  bool ShouldOptimizeFunctions() override { return false; }

  // Builds a single-iteration RewriterConfig that runs only the auto-shard
  // custom optimizer with "num_workers" and "index" parameters.
  RewriterConfig CreateGrapplerRewriteConfig() override {
    RewriterConfig config;
    config.set_fail_on_optimizer_errors(true);
    config.add_optimizers(kOptimizerName);
    config.set_meta_optimizer_iterations(RewriterConfig_NumIterationsType_ONE);

    auto* optimizer = config.add_custom_optimizers();
    optimizer->set_name(kOptimizerName);
    auto& parameters = *optimizer->mutable_parameter_map();
    parameters["num_workers"].set_i(num_workers_);
    parameters["index"].set_i(index_);
    return config;
  }

  const int64 num_workers_;
  const int64 index_;
};
const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
// Builds a single-iteration RewriterConfig that fails on optimizer errors and
// runs only the custom optimizer named by `kOptimizerName`, passing it
// `num_workers` and `index` through its parameter map.
static RewriterConfig CreateConfig(int64 num_workers, int64 index) {
  RewriterConfig config;
  config.set_fail_on_optimizer_errors(true);
  config.add_optimizers(kOptimizerName);
  config.set_meta_optimizer_iterations(RewriterConfig_NumIterationsType_ONE);

  auto* optimizer = config.add_custom_optimizers();
  optimizer->set_name(kOptimizerName);
  auto& parameters = *optimizer->mutable_parameter_map();
  parameters["num_workers"].set_i(num_workers);
  parameters["index"].set_i(index);
  return config;
}
};
REGISTER_KERNEL_BUILDER(Name("ExperimentalAutoShardDataset").Device(DEVICE_CPU),
......
......@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
namespace data {
......@@ -24,11 +25,7 @@ constexpr char kOptimizerName[] = "tf_data_rebatcher";
class RebatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit RebatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
graph_def_version_(ctx->graph_def_version()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
: UnaryDatasetOpKernel(ctx) {}
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
......@@ -39,58 +36,32 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel {
ctx, num_workers > 0,
errors::InvalidArgument("num_workers must be greater than zero."));
Dataset* dataset =
new Dataset(ctx, input, num_workers, output_types_, output_shapes_);
Status s = dataset->Optimize(ctx);
if (s.ok()) {
*output = dataset;
} else {
dataset->Unref();
OP_REQUIRES_OK(ctx, s);
}
auto config_factory = [num_workers]() { return CreateConfig(num_workers); };
// We only want to optimize functions for some particular datasets like
// FlatMapDataset, InterleaveDataset etc. So we disable generalized
// function optimization and explicitly handle function modifications
// for those datasets in the rewrite.
OP_REQUIRES_OK(ctx,
RewriteDataset(ctx, input, std::move(config_factory),
/*optimize_function_library=*/false, output));
}
private:
// Graph-rewriting dataset that applies the rebatching optimizer, parameterized
// by the number of workers.
class Dataset : public GraphRewriteDataset {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input,
          const int64 num_workers, const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes)
      : GraphRewriteDataset(ctx, input, output_types, output_shapes),
        num_workers_(num_workers) {}

  string DebugString() const override { return "RebatchDatasetOp::Dataset"; }

 private:
  // Generalized function optimization stays off: only specific datasets
  // (FlatMapDataset, InterleaveDataset, ...) should have their functions
  // modified, and the rewrite handles those explicitly.
  bool ShouldOptimizeFunctions() override { return false; }

  // Builds a single-iteration RewriterConfig that runs only the rebatching
  // custom optimizer with a "num_workers" parameter.
  RewriterConfig CreateGrapplerRewriteConfig() override {
    RewriterConfig config;
    config.set_fail_on_optimizer_errors(true);
    config.add_optimizers(kOptimizerName);
    config.set_meta_optimizer_iterations(RewriterConfig_NumIterationsType_ONE);

    auto* optimizer = config.add_custom_optimizers();
    optimizer->set_name(kOptimizerName);
    (*optimizer->mutable_parameter_map())["num_workers"].set_i(num_workers_);
    return config;
  }

  const int64 num_workers_;
};
const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
// Builds a single-iteration RewriterConfig that fails on optimizer errors and
// runs only the custom optimizer named by `kOptimizerName`, passing it
// `num_workers` through its parameter map.
static RewriterConfig CreateConfig(int64 num_workers) {
  RewriterConfig config;
  config.set_fail_on_optimizer_errors(true);
  config.add_optimizers(kOptimizerName);
  config.set_meta_optimizer_iterations(RewriterConfig_NumIterationsType_ONE);

  auto* optimizer = config.add_custom_optimizers();
  optimizer->set_name(kOptimizerName);
  (*optimizer->mutable_parameter_map())["num_workers"].set_i(num_workers);
  return config;
}
};
REGISTER_KERNEL_BUILDER(Name("ExperimentalRebatchDataset").Device(DEVICE_CPU),
......
......@@ -123,7 +123,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "path", &path));
GraphDef graph_def;
OP_REQUIRES_OK(ctx, AsGraphDef(ctx, input, &graph_def));
OP_REQUIRES_OK(
ctx, AsGraphDef(ctx, input, SerializationContext({}), &graph_def));
// TODO(frankchn): Find a better way than SerializeToStringDeterministic()
// This is not deterministic across different builds of binaries right now.
......
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h"
#include <memory>
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
namespace data {
// Releases the reference held on the input dataset and, when an optimized
// dataset was produced, the reference held on it as well.
GraphRewriteDataset::~GraphRewriteDataset() {
  input_->Unref();
  if (optimized_input_ != nullptr) {
    optimized_input_->Unref();
  }
}
// Serializes the input dataset to a graph, applies the Grappler optimizations
// to it, runs the optimized graph and stores the resulting dataset in
// `optimized_input_` (taking a reference on it).
//
// Also (re)initializes the cloned function library members and the function
// handle cache used later by the iterator.
Status GraphRewriteDataset::Optimize(OpKernelContext* ctx) {
  GraphDefBuilder builder;
  DatasetGraphDefBuilder dataset_builder(&builder);
  Node* serialized_input = nullptr;

  // Serialize in "optimization only" mode, collecting feeds in `feed_list`.
  SerializationContext::Params ser_params;
  std::vector<std::pair<string, Tensor>> feed_list;
  ser_params.input_list = &feed_list;
  ser_params.optimization_only = true;
  SerializationContext ser_ctx(ser_params);
  TF_RETURN_IF_ERROR(
      dataset_builder.AddInputDataset(&ser_ctx, input_, &serialized_input));
  string fetch_node = serialized_input->name();

  GraphDef optimized_graph;
  TF_RETURN_IF_ERROR(builder.ToGraphDef(&optimized_graph));
  VLOG(3) << "Before optimization: " << optimized_graph.DebugString();
  TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &optimized_graph, &fetch_node));
  VLOG(3) << "After optimization: " << optimized_graph.DebugString();

  // Instantiate the optimized input pipeline by running the optimized graph
  // using the optimized function library.
  TF_RETURN_IF_ERROR(
      ctx->function_library()->Clone(&lib_def_, &pflr_, &flr_, true));

  // Create a FunctionHandleCache.
  function_handle_cache_ = absl::make_unique<FunctionHandleCache>(flr_);

  // Functions can be rewritten in place without being renamed (e.g. the
  // nested dataset graphs of FlatMap or Interleave), so merge the rewritten
  // library into the clone.
  TF_RETURN_IF_ERROR(
      AddToFunctionLibrary(lib_def_.get(), optimized_graph.library()));

  Graph imported_graph(OpRegistry::Global());
  TF_RETURN_IF_ERROR(
      ImportGraphDef({}, optimized_graph, &imported_graph, nullptr));

  std::vector<Tensor> fetched;
  GraphRunner runner(flr_->device());
  TF_RETURN_IF_ERROR(
      runner.Run(&imported_graph, flr_, feed_list, {fetch_node}, &fetched));
  TF_RETURN_IF_ERROR(
      GetDatasetFromVariantTensor(fetched[0], &optimized_input_));
  optimized_input_->Ref();
  return Status::OK();
}
// Serializes this dataset by emitting the already-optimized input.
//
// Only the optimized dataset is serialized, so that restoring the input
// pipeline from a checkpoint does not re-run the optimizations.
Status GraphRewriteDataset::AsGraphDefInternal(SerializationContext* ctx,
                                               DatasetGraphDefBuilder* b,
                                               Node** output) const {
  return b->AddInputDataset(ctx, optimized_input_, output);
}
namespace {
// For every output argument of `function_def`, appends an Identity node named
// with the "FakeSink<N>" prefix that consumes the tensor currently returned
// for that argument, then points the return value at the new node's output.
void AddFakeSinks(FunctionDef* function_def) {
  int sink_index = 0;
  for (const auto& out_arg : function_def->signature().output_arg()) {
    const string& ret_key = out_arg.name();

    NodeDef* sink = function_def->add_node_def();
    const string name_prefix = strings::StrCat("FakeSink", sink_index);
    ++sink_index;
    tensorflow::grappler::function_utils::SetUniqueFunctionNodeName(
        name_prefix, function_def, sink);

    sink->set_op("Identity");
    // The sink forwards whatever tensor currently backs this return value.
    sink->add_input(function_def->ret().at(ret_key));
    auto& type_attr = (*sink->mutable_attr())["T"];
    type_attr.set_type(out_arg.type());

    // Route the return value through the freshly added sink.
    (*function_def->mutable_ret())[ret_key] =
        strings::StrCat(sink->name(), ":output:0");
  }
}
// Rewires each return value of `function_def` that is produced by a
// single-input Identity node so that it refers to that node's input tensor
// instead. The Identity nodes themselves stay in the function body.
void RemoveFakeSinks(FunctionDef* function_def) {
  // Name of each single-input Identity node -> the tensor it forwards.
  std::map<string, string> forwarded_input;
  for (const auto& candidate : function_def->node_def()) {
    if (candidate.op() != "Identity" || candidate.input_size() != 1) {
      continue;
    }
    forwarded_input[candidate.name()] = candidate.input(0);
  }

  for (const auto& out_arg : function_def->signature().output_arg()) {
    const string& ret_tensor = function_def->ret().at(out_arg.name());
    // The producing node's name is everything before the first ':'.
    const string producer = ret_tensor.substr(0, ret_tensor.find(':'));
    const auto match = forwarded_input.find(producer);
    if (match == forwarded_input.end()) {
      continue;
    }
    (*function_def->mutable_ret())[out_arg.name()] = match->second;
  }
}
} // anonymous namespace
// Runs Grappler's meta optimizer over `*graph_def`, using the configuration
// from CreateGrapplerRewriteConfig(), and writes the optimized graph back into
// `*graph_def`.
//
// Args:
//   ctx: kernel context; its device is handed to the meta optimizer.
//   graph_def: in/out graph. It is modified before optimization starts (a
//     "Sink" Identity node and per-function fake sinks are appended), so it is
//     changed even when an error is returned.
//   output_node: in/out. On entry, the name of the node to fetch; on exit, the
//     name of the Identity node added in front of it.
//
// Returns Internal if no Grappler item could be built from the graph, or the
// status reported by the meta optimizer.
Status GraphRewriteDataset::ApplyOptimizations(OpKernelContext* ctx,
                                               GraphDef* graph_def,
                                               string* output_node) {
  // Add an identity node as the fetch node, otherwise we might get 'placeholder
  // is both fed and fetched' errors in some cases when using input list with
  // placeholder dataset nodes.
  NodeDef* node = graph_def->mutable_node()->Add();
  tensorflow::grappler::graph_utils::SetUniqueGraphNodeName("Sink", graph_def,
                                                            node);
  node->set_op("Identity");
  node->add_input(*output_node);
  (*node->mutable_attr())["T"].set_type(DT_VARIANT);
  *output_node = node->name();

  // Add fake sink node to graph and functions to allow rewriting the actual
  // sink nodes.
  //
  // TODO(b/118820916): When MetaOptimizer adds provisions for function retvals
  // to be optimizable, we will no longer need this.
  for (auto& function_def : *graph_def->mutable_library()->mutable_function()) {
    AddFakeSinks(&function_def);
  }

  // Create metagraph. The copy must be taken after the sinks were added.
  MetaGraphDef meta_graph_def;
  (*meta_graph_def.mutable_graph_def()) = *graph_def;

  // Grappler determines fetch ops from collection 'train_op'.
  CollectionDef collection_def;
  auto node_list = collection_def.mutable_node_list();
  node_list->add_value(*output_node);
  (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;

  // Create Grappler item.
  tensorflow::grappler::ItemConfig item_config;
  item_config.apply_optimizations = true;
  std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
      tensorflow::grappler::GrapplerItemFromMetaGraphDef(
          "graph", meta_graph_def, item_config);
  if (grappler_item == nullptr) {
    // Item construction failed; report it instead of dereferencing null.
    return errors::Internal(
        "Failed to create a Grappler item from the dataset graph.");
  }
  grappler_item->optimization_options().optimize_function_library =
      ShouldOptimizeFunctions();
  std::unordered_map<string, tensorflow::DeviceProperties> device_map;
  tensorflow::grappler::VirtualCluster cluster(device_map);

  // Run data optimizer using grappler's meta optimizer.
  tensorflow::ConfigProto config;
  *config.mutable_graph_options()->mutable_rewrite_options() =
      CreateGrapplerRewriteConfig();
  TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
      *grappler_item, config, ctx->device(), &cluster, graph_def));

  // Remove fake sinks after optimizations are done.
  //
  // TODO(b/118820916): When MetaOptimizer adds provisions for function retvals
  // to be optimizable, we will no longer need this.
  for (auto& function_def : *graph_def->mutable_library()->mutable_function()) {
    RemoveFakeSinks(&function_def);
  }

  return Status::OK();
}
// Iterator for `GraphRewriteDataset`. It owns an iterator over the dataset's
// `optimized_input_` and forwards every operation to it, substituting the
// dataset's function library runtime and function handle cache into the
// iterator context first.
class GraphRewriteDataset::Iterator
    : public DatasetIterator<GraphRewriteDataset> {
 public:
  explicit Iterator(const Params& params)
      : DatasetIterator<GraphRewriteDataset>(params) {}

  // Creates the wrapped iterator over the optimized input dataset.
  Status Initialize(IteratorContext* ctx) override {
    return dataset()->optimized_input_->MakeIterator(
        IteratorContext(ParamsWithDatasetRuntime(ctx)), prefix(),
        &input_impl_);
  }

  // Produces the next element by delegating to the wrapped iterator.
  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override {
    return input_impl_->GetNext(IteratorContext(ParamsWithDatasetRuntime(ctx)),
                                out_tensors, end_of_sequence);
  }

 protected:
  // Modeled as a known-ratio node with ratio 1.
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
  }

  // Only the wrapped iterator's state is saved.
  Status SaveInternal(IteratorStateWriter* writer) override {
    return SaveInput(writer, input_impl_);
  }

  // Only the wrapped iterator's state is restored.
  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override {
    return RestoreInput(ctx, reader, input_impl_);
  }

 private:
  // Builds iterator-context parameters from `ctx`, overriding the function
  // library runtime and function handle cache with the ones held by the
  // dataset.
  IteratorContext::Params ParamsWithDatasetRuntime(IteratorContext* ctx) const {
    IteratorContext::Params rewritten(ctx);
    rewritten.flr = dataset()->flr_;
    rewritten.function_handle_cache = dataset()->function_handle_cache_.get();
    return rewritten;
  }

  // Iterator over `optimized_input_`; set by `Initialize`.
  std::unique_ptr<IteratorBase> input_impl_;
};
std::unique_ptr<IteratorBase> GraphRewriteDataset::MakeIteratorInternal(
const string& prefix) const {
// We do not add a token for this dataset to the prefix. The
// prefix is used to identify checkpoint elements and since this
// dataset is excluded from the checkpoint, adding a token
// here would result in invalid checkpoint identifiers.
return absl::make_unique<Iterator>(Iterator::Params{this, prefix});
}
} // namespace data
} // namespace tensorflow
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_GRAPH_REWRITE_DATASET_H_
#define TENSORFLOW_CORE_KERNELS_DATA_GRAPH_REWRITE_DATASET_H_
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
namespace tensorflow {
namespace data {
// Abstract dataset that wraps an input dataset and exposes a Grappler-rewritten
// version of it. Subclasses provide the rewrite configuration through
// `CreateGrapplerRewriteConfig()`.
class GraphRewriteDataset : public DatasetBase {
 public:
  // Stores `input` together with the advertised output signature and takes a
  // reference on `input`.
  GraphRewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
                      const DataTypeVector& output_types,
                      const std::vector<PartialTensorShape>& output_shapes)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        output_types_(output_types),
        output_shapes_(output_shapes) {
    input_->Ref();
  }

  // Defined out of line.
  ~GraphRewriteDataset() override;

  // Runs Grappler to transform the input dataset into optimized_input_
  // dataset.
  Status Optimize(OpKernelContext* ctx);

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override;

  // Output signature as supplied at construction time.
  const DataTypeVector& output_dtypes() const override {
    return output_types_;
  }
  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  // Reports the cardinality of the original (unoptimized) input.
  int64 Cardinality() const override {
    return input_->Cardinality();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override;

 private:
  class Iterator;

  // Create a Grappler RewriteConfig proto that defines the list of
  // optimizations to be run by the Grappler Meta Optimizer.
  virtual RewriterConfig CreateGrapplerRewriteConfig() = 0;

  // Option specifying whether we want to optimize the function library as well.
  virtual bool ShouldOptimizeFunctions() {
    return true;
  }

  // Runs the Grappler meta optimizer over `graph_def`, updating `graph_def`
  // and `output_node` in place.
  Status ApplyOptimizations(OpKernelContext* ctx, GraphDef* graph_def,
                            string* output_node);

  // NOTE: the declaration order below determines construction and destruction
  // order of the members; keep it stable.

  // Dataset that iterators are created from; null until one is assigned.
  DatasetBase* optimized_input_ = nullptr;
  // Function library runtime handed to iterator contexts.
  FunctionLibraryRuntime* flr_ = nullptr;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  std::unique_ptr<FunctionLibraryDefinition> lib_def_;
  // Function handle cache handed to iterator contexts.
  std::unique_ptr<FunctionHandleCache> function_handle_cache_;
  // Original input dataset; a reference is taken in the constructor.
  const DatasetBase* input_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_GRAPH_REWRITE_DATASET_H_
......@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
......@@ -32,12 +32,9 @@ constexpr char kOptimizerName[] = "tf_data_meta_optimizer";
class OptimizeDatasetOp : public UnaryDatasetOpKernel {
public:
explicit OptimizeDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
graph_def_version_(ctx->graph_def_version()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
OP_REQUIRES_OK(ctx,
ctx->GetAttr("optimization_configs", &optimizer_configs_));
: UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(
ctx, ctx->GetAttr("optimization_configs", &optimization_configs_));
}
protected:
......@@ -46,62 +43,41 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
std::vector<string> optimizations;
OP_REQUIRES_OK(
ctx, ParseVectorArgument<string>(ctx, "optimizations", &optimizations));
Dataset* dataset = new Dataset(ctx, input, optimizations, output_types_,
output_shapes_, optimizer_configs_);
Status s = dataset->Optimize(ctx);
if (s.ok()) {
*output = dataset;
} else {
dataset->Unref();
OP_REQUIRES_OK(ctx, s);
}
auto config_factory = [this, &optimizations]() {
return CreateConfig(optimizations, optimization_configs_);
};
OP_REQUIRES_OK(ctx,
RewriteDataset(ctx, input, std::move(config_factory),
/*optimize_function_library=*/true, output));
}
private:
class Dataset : public GraphRewriteDataset {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const std::vector<string>& optimizations,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
const std::vector<string>& optimizer_configs)
: GraphRewriteDataset(ctx, input, output_types, output_shapes),
optimizations_(optimizations),
optimizer_configs_(optimizer_configs) {}
string DebugString() const override { return "OptimizeDatasetOp::Dataset"; }
private:
RewriterConfig CreateGrapplerRewriteConfig() override {
RewriterConfig rewriter_config;
rewriter_config.add_optimizers(kOptimizerName);
rewriter_config.set_meta_optimizer_iterations(
RewriterConfig_NumIterationsType_ONE);
auto custom_optimizer = rewriter_config.add_custom_optimizers();
custom_optimizer->set_name(kOptimizerName);
auto* custom_optimizations_list =
(*custom_optimizer->mutable_parameter_map())["optimizers"]
.mutable_list();
for (const auto& opt : optimizations_) {
custom_optimizations_list->add_s(opt);
}
auto* config_list =
(*custom_optimizer->mutable_parameter_map())["optimizer_configs"]
.mutable_list();
for (const auto& config : optimizer_configs_) {
config_list->add_s(config);
}
return rewriter_config;
static RewriterConfig CreateConfig(
std::vector<string> optimizations,
std::vector<string> optimizations_configs) {
RewriterConfig rewriter_config;
rewriter_config.add_optimizers(kOptimizerName);
rewriter_config.set_meta_optimizer_iterations(
RewriterConfig_NumIterationsType_ONE);
auto custom_optimizer = rewriter_config.add_custom_optimizers();
custom_optimizer->set_name(kOptimizerName);
auto* custom_optimizations_list =
(*custom_optimizer->mutable_parameter_map())["optimizers"]
.mutable_list();
for (const auto& opt : optimizations) {
custom_optimizations_list->add_s(opt);
}
auto* config_list =
(*custom_optimizer->mutable_parameter_map())["optimizer_configs"]
.mutable_list();
for (const auto& config : optimizations_configs) {
config_list->add_s(config);
}
return rewriter_config;
}
const std::vector<string> optimizations_;
const std::vector<string> optimizer_configs_;
};
const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
std::vector<string> optimizer_configs_;
std::vector<string> optimization_configs_;
};
REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册