/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

namespace tensorflow {
namespace data {
namespace {

// See documentation in ../../ops/dataset_ops.cc for a high-level
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
// description of the following op.
class OptimizeDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit OptimizeDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx),
        graph_def_version_(ctx->graph_def_version()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    std::vector<string> optimizations;
    OP_REQUIRES_OK(
        ctx, ParseVectorArgument<string>(ctx, "optimizations", &optimizations));
    Dataset* dataset =
61
        new Dataset(ctx, input, optimizations, output_types_, output_shapes_);
62 63 64 65 66 67 68
    Status s = dataset->Optimize(ctx);
    if (s.ok()) {
      *output = dataset;
    } else {
      dataset->Unref();
      OP_REQUIRES_OK(ctx, s);
    }
69 70 71
  }

 private:
72
  class Dataset : public DatasetBase {
73
   public:
74 75
    Dataset(OpKernelContext* ctx, const DatasetBase* input,
            const std::vector<string>& optimizations,
76 77
            const DataTypeVector& output_types,
            const std::vector<PartialTensorShape>& output_shapes)
78
        : DatasetBase(DatasetContext(ctx)),
79
          optimized_input_(nullptr),
80
          input_(input),
81 82
          optimizations_(optimizations),
          output_types_(output_types),
83 84 85
          output_shapes_(output_shapes) {
      input_->Ref();
    }
86

87 88
    ~Dataset() override {
      input_->Unref();
89 90 91
      if (optimized_input_) {
        optimized_input_->Unref();
      }
92
    }
93 94 95

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
96 97 98 99 100
      // We do not add a token for the optimization dataset to the prefix. The
      // prefix is used to identify checkpoint elements and since the
      // optimization dataset is excluded from the checkpoint, adding a token
      // here would result in invalid checkpoint identifiers.
      return std::unique_ptr<IteratorBase>(new Iterator({this, prefix}));
101 102
    }

103
    Status Optimize(OpKernelContext* ctx) {
104 105 106
      GraphDefBuilder b;
      DatasetGraphDefBuilder db(&b);
      Node* input_node = nullptr;
107
      SerializationContext::Params params;
108
      std::vector<std::pair<string, Tensor>> input_list;
109
      params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
110
      params.input_list = &input_list;
111
      params.optimization_only = true;
112 113 114
      SerializationContext serialization_ctx(params);
      TF_RETURN_IF_ERROR(
          db.AddInputDataset(&serialization_ctx, input_, &input_node));
115
      string output_node = input_node->name();
P
Piotr Padlewski 已提交
116

117 118
      GraphDef graph_def;
      TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
119
      VLOG(3) << "Before optimization: " << graph_def.DebugString();
P
Piotr Padlewski 已提交
120

121
      TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &graph_def, &output_node));
122
      VLOG(3) << "After optimization: " << graph_def.DebugString();
P
Piotr Padlewski 已提交
123 124 125 126 127

      // Instantiate the optimized input pipeline by running the optimized graph
      // using the optimized function library.
      TF_RETURN_IF_ERROR(
          ctx->function_library()->Clone(&flib_def_, &pflr_, &lib_));
128

129 130 131
      // Create a FunctionHandleCache.
      function_handle_cache_.reset(new FunctionHandleCache(lib_));

132 133 134 135 136 137 138 139 140
      // Some functions may have been modified without having their names
      // changed (for example, nested dataset graphs from FlatMap or
      // Interleave). To avoid name conflicts, we remove these functions from
      // flib_def_ before adding the optimized function library.
      for (const FunctionDef& fd : graph_def.library().function()) {
        if (flib_def_->Find(fd.signature().name()) != nullptr) {
          TF_RETURN_IF_ERROR(flib_def_->RemoveFunction(fd.signature().name()));
        }
      }
P
Piotr Padlewski 已提交
141 142
      TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph_def.library()));

143 144 145
      Graph graph(OpRegistry::Global());
      TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
      std::vector<Tensor> outputs;
146
      GraphRunner graph_runner(ctx->function_library()->device());
P
Piotr Padlewski 已提交
147 148

      TF_RETURN_IF_ERROR(
149
          graph_runner.Run(&graph, lib_, input_list, {output_node}, &outputs));
150 151 152
      TF_RETURN_IF_ERROR(
          GetDatasetFromVariantTensor(outputs[0], &optimized_input_));
      optimized_input_->Ref();
153 154 155 156 157 158 159 160 161 162 163 164
      return Status::OK();
    }

    const DataTypeVector& output_dtypes() const override {
      return output_types_;
    }
    const std::vector<PartialTensorShape>& output_shapes() const override {
      return output_shapes_;
    }

    string DebugString() const override { return "OptimizeDatasetOp::Dataset"; }

165 166
    int64 Cardinality() const override { return input_->Cardinality(); }

167
   protected:
168 169
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
170
                              Node** output) const override {
171 172 173
      // We only serialize the optimized dataset to avoid re-running
      // optimizations when the input pipeline is restored from a checkpoint.
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, optimized_input_, output));
174 175 176
      return Status::OK();
    }

177 178 179 180 181 182 183
   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params) {}

      Status Initialize(IteratorContext* ctx) override {
184
        IteratorContext::Params params(ctx);
P
Piotr Padlewski 已提交
185
        params.lib = dataset()->lib_;
186
        params.function_handle_cache = dataset()->function_handle_cache_.get();
187
        return dataset()->optimized_input_->MakeIterator(
188
            IteratorContext(std::move(params)), prefix(), &input_impl_);
189 190 191 192 193
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
194
        IteratorContext::Params params(ctx);
P
Piotr Padlewski 已提交
195
        params.lib = dataset()->lib_;
196
        params.function_handle_cache = dataset()->function_handle_cache_.get();
197 198
        return input_impl_->GetNext(IteratorContext(std::move(params)),
                                    out_tensors, end_of_sequence);
199 200 201
      }

     protected:
202 203 204 205 206 207
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

208
      Status SaveInternal(IteratorStateWriter* writer) override {
J
Jiri Simsa 已提交
209
        TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
210
        return Status::OK();
211
      }
212 213 214

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
J
Jiri Simsa 已提交
215
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
216 217 218 219 220
        return Status::OK();
      }

     private:
      std::unique_ptr<IteratorBase> input_impl_;
221 222
    };

223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255
    void AddFakeSinks(FunctionDef* function_def) {
      int counter = 0;
      for (const auto& output : function_def->signature().output_arg()) {
        NodeDef* node = function_def->add_node_def();
        tensorflow::grappler::function_utils::SetUniqueFunctionNodeName(
            strings::StrCat("FakeSink", counter++), function_def, node);
        node->set_op("Identity");
        node->add_input(function_def->ret().at(output.name()));
        (*node->mutable_attr())["T"].set_type(output.type());

        (*function_def->mutable_ret())[output.name()] =
            strings::StrCat(node->name(), ":output:0");
      }
    }

    void RemoveFakeSinks(FunctionDef* function_def) {
      // Map from identity node names to their input tensor strings
      std::map<string, string> identity_map;
      for (const auto& node : function_def->node_def()) {
        if (node.op() == "Identity" && node.input_size() == 1) {
          identity_map[node.name()] = node.input(0);
        }
      }
      for (const auto& output_arg : function_def->signature().output_arg()) {
        const string& tensor = function_def->ret().at(output_arg.name());
        const string& output_node = tensor.substr(0, tensor.find(':'));
        if (identity_map.find(output_node) != identity_map.end()) {
          (*function_def->mutable_ret())[output_arg.name()] =
              identity_map.at(output_node);
        }
      }
    }

256 257
    Status ApplyOptimizations(OpKernelContext* ctx, GraphDef* graph_def,
                              string* output_node) {
258 259 260
      // Add an identity node as the fetch node, otherwise we might get
      // 'placeholder is both fed and fetched' errors in some cases when using
      // input list with placeholder dataset nodes.
261
      NodeDef* node = graph_def->mutable_node()->Add();
262 263 264
      tensorflow::grappler::graph_utils::SetUniqueGraphNodeName(
          "Sink", graph_def, node);
      node->set_op("Identity");
265
      node->add_input(*output_node);
266 267
      (*node->mutable_attr())["T"].set_type(DT_VARIANT);
      *output_node = node->name();
268

269 270 271 272 273 274 275 276 277
      // Add fake sink node to graph and functions to allow rewriting the actual
      // sink nodes.
      // TODO(b/118820916): When MetaOptimizer adds provisions for function
      // retvals to be optimizable, we will no longer need this.
      for (auto& function_def :
           *graph_def->mutable_library()->mutable_function()) {
        AddFakeSinks(&function_def);
      }

278 279 280 281 282 283 284
      // Create metagraph.
      MetaGraphDef meta_graph_def;
      (*meta_graph_def.mutable_graph_def()) = *graph_def;

      // Grappler determines fetch ops from collection 'train_op'.
      CollectionDef collection_def;
      auto node_list = collection_def.mutable_node_list();
285
      node_list->add_value(*output_node);
286 287 288
      (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;

      // Create Grappler item.
289 290 291
      tensorflow::ConfigProto config;
      RewriterConfig& rewriter_config =
          *config.mutable_graph_options()->mutable_rewrite_options();
292 293 294
      for (const string& optimization : optimizations_) {
        rewriter_config.add_optimizers(optimization);
      }
295 296 297 298
      // If no optimizations were specified, supply a non-existent
      // optimization to prevent Grappler from applying the default set of
      // optimizations as some of them do not work out of the box at the
      // moment (e.g. because we have no cost model for dataset ops).
299 300
      if (optimizations_.empty()) {
        rewriter_config.add_optimizers("non-existent");
301 302 303 304 305 306 307
      } else {
        // If we apply custom dataset optimizers, explicitly trigger a subset of
        // standard grappler optimizations to further optimize modified dataset
        // graphs (e.g. performing constant folding on merged functions,
        // removing unused graph nodes)
        // TODO(b/118175421): This should be part of the tf.data optimization
        // pass manager.
308 309 310
        // TODO(b/120437209): Apply `constfold` optimization when it is fixed.
        for (const auto& optimizer :
             {"pruning", "function", "shape", "arithmetic", "dependency"}) {
311 312
          rewriter_config.add_optimizers(optimizer);
        }
313 314 315 316 317 318 319 320 321 322
      }
      tensorflow::grappler::ItemConfig item_config;
      item_config.apply_optimizations = true;
      std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
          tensorflow::grappler::GrapplerItemFromMetaGraphDef(
              "graph", meta_graph_def, item_config);
      std::unordered_map<string, tensorflow::DeviceProperties> device_map;
      tensorflow::grappler::VirtualCluster cluster(device_map);

      // Run optimizer.
323 324 325 326 327 328
      if (VLOG_IS_ON(2)) {
        LOG(INFO) << "Performing the following optimizations:";
        for (const string& optimization : optimizations_) {
          LOG(INFO) << "  " << optimization;
        }
      }
329
      TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
330
          *grappler_item, config, ctx->device(), &cluster, graph_def));
331

332 333 334 335 336 337 338 339
      // Remove fake sinks after optimizations are done.
      // TODO(b/118820916): When MetaOptimizer adds provisions for function
      // retvals to be optimizable, we will no longer need this.
      for (auto& function_def :
           *graph_def->mutable_library()->mutable_function()) {
        RemoveFakeSinks(&function_def);
      }

340 341 342
      return Status::OK();
    }

343
    DatasetBase* optimized_input_;
P
Piotr Padlewski 已提交
344 345 346
    FunctionLibraryRuntime* lib_ = nullptr;
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_ = nullptr;
    std::unique_ptr<FunctionLibraryDefinition> flib_def_ = nullptr;
347
    std::unique_ptr<FunctionHandleCache> function_handle_cache_ = nullptr;
348
    const DatasetBase* input_;
349 350 351 352 353 354 355 356 357 358 359 360 361 362
    const std::vector<string> optimizations_;
    const DataTypeVector output_types_;
    const std::vector<PartialTensorShape> output_shapes_;
  };

  const int graph_def_version_;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
};

REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU),
                        OptimizeDatasetOp);

}  // namespace
}  // namespace data
}  // namespace tensorflow