提交 0f2db739 编写于 作者: P Peter Hawkins 提交者: TensorFlower Gardener

[TF:XLA] Split union-find implementation in mark_for_compilation_pass.cc into...

[TF:XLA] Split union-find implementation in mark_for_compilation_pass.cc into a separate library, make it more generic.

PiperOrigin-RevId: 157850985
上级 d5421cf5
......@@ -202,6 +202,7 @@ cc_library(
deps = [
":common",
":graph_to_functiondef",
":union_find",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/kernels:parallel_check_op",
"//tensorflow/compiler/jit/kernels:xla_local_launch_op",
......@@ -221,6 +222,11 @@ cc_library(
],
)
cc_library(
name = "union_find",
hdrs = ["union_find.h"],
)
cc_test(
name = "compilation_passes_test",
size = "small",
......
......@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
......@@ -206,70 +207,12 @@ Status FindCompilationCandidates(
return Status::OK();
}
// Union-Find data structure used to compute clusters. We use our own
// implementation because we want one key feature: when merging clusters, we
// need to know which value becomes the representative of the merged clusters.
// We use the representatives to name nodes in a cycle detection graph, and we
// need to control which node is named.
// TODO(phawkins): consider merging this code with union-find implementations
// in Tensorflow, e.g., in SimplePlacer.
class Cluster {
public:
Cluster();
int Size() { return FindRoot()->size_; }
// Merges this cluster with 'other'. This cluster's representative becomes
// the representative of the merged cluster; the representative of 'other'
// is ignored.
void Merge(Cluster* other);
// Each cluster has an associated integer 'representative', initialized to -1
// by default.
int GetRepresentative() { return FindRoot()->representative_; }
void SetRepresentative(int representative) {
FindRoot()->representative_ = representative;
}
private:
// Finds the root element of the cluster. Performs path compression.
Cluster* FindRoot();
int representative_;
int rank_;
int size_; // Size of the cluster.
Cluster* parent_;
struct Cluster {
// Identifies the node that represents this cluster in the cycle detection
// graph.
int representative = -1;
};
Cluster::Cluster()
: representative_(-1), rank_(0), size_(1), parent_(nullptr) {}
void Cluster::Merge(Cluster* other) {
Cluster* a = FindRoot();
Cluster* b = other->FindRoot();
if (a == b) return;
if (a->rank_ > b->rank_) {
b->parent_ = a;
a->size_ += b->size_;
return;
}
a->parent_ = b;
if (a->rank_ == b->rank_) {
b->rank_++;
}
b->representative_ = a->representative_;
b->size_ += a->size_;
}
Cluster* Cluster::FindRoot() {
if (!parent_) return this;
// Path compression: update intermediate nodes to point to the root of the
// equivalence class.
parent_ = parent_->FindRoot();
return parent_;
}
} // anonymous namespace
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
......@@ -432,10 +375,11 @@ Status MarkForCompilationPass::RunImpl(
// Each compilation candidate belongs to a cluster. The cluster's
// representative
// names the node in the 'cycles' graph that represents the cluster.
std::vector<Cluster> clusters(graph->num_node_ids());
std::deque<Cluster*> worklist;
std::vector<UnionFind<Cluster>> clusters(graph->num_node_ids());
std::deque<UnionFind<Cluster>*> worklist;
for (Node* node : compilation_candidates) {
clusters[node->id()].SetRepresentative(node->id());
Cluster& cluster = clusters[node->id()].Get();
cluster.representative = node->id();
worklist.push_back(&clusters[node->id()]);
}
......@@ -445,7 +389,7 @@ Status MarkForCompilationPass::RunImpl(
// Repeatedly contract edges between clusters that are on the same device,
// provided the contraction would not create a cycle.
while (!worklist.empty()) {
int from = worklist.front()->GetRepresentative();
int from = worklist.front()->Get().representative;
worklist.pop_front();
Node* node_from = graph->FindNodeId(from);
......@@ -518,7 +462,7 @@ Status MarkForCompilationPass::RunImpl(
// Count the number of elements in each cluster.
std::vector<int> cluster_sizes(graph->num_node_ids());
for (const Node* n : compilation_candidates) {
int cluster = clusters[n->id()].GetRepresentative();
int cluster = clusters[n->id()].Get().representative;
cluster_sizes[cluster]++;
}
......@@ -532,7 +476,7 @@ Status MarkForCompilationPass::RunImpl(
// if compilation is enabled, otherwise there will be no such candidates).
const int min_cluster_size = flags->tf_xla_min_cluster_size;
for (Node* n : compilation_candidates) {
int cluster = clusters[n->id()].GetRepresentative();
int cluster = clusters[n->id()].Get().representative;
// Compile if the user marked this node _XlaCompile=true
bool compile_attr = false;
......
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_UNION_FIND_H_
#define TENSORFLOW_COMPILER_JIT_UNION_FIND_H_
namespace tensorflow {
// Union-Find data structure.
// Each cluster has an associated value; when merging clusters we can control
// which value becomes the representative of the merged clusters. Values must be
// copyable.
template <typename T>
class UnionFind {
public:
UnionFind() : rank_(0), size_(1), parent_(nullptr) {}
// Returns the number of elements in a cluster.
int Size() { return FindRoot()->size_; }
// Merges this cluster with 'other'. This cluster's value becomes
// the value of the merged cluster; the value of 'other' is ignored.
void Merge(UnionFind* other);
// Each cluster has an associated value. Retrieves the value associated
// with this cluster.
T& Get() { return FindRoot()->value_; }
private:
// Finds the root element of the cluster. Performs path compression.
UnionFind* FindRoot();
int rank_;
int size_; // Size of the cluster.
UnionFind* parent_;
T value_;
};
template <typename T>
void UnionFind<T>::Merge(UnionFind* other) {
UnionFind<T>* a = FindRoot();
UnionFind<T>* b = other->FindRoot();
if (a == b) return;
if (a->rank_ > b->rank_) {
b->parent_ = a;
a->size_ += b->size_;
return;
}
a->parent_ = b;
if (a->rank_ == b->rank_) {
b->rank_++;
}
b->value_ = a->value_;
b->size_ += a->size_;
}
template <typename T>
UnionFind<T>* UnionFind<T>::FindRoot() {
if (!parent_) return this;
// Path compression: update intermediate nodes to point to the root of the
// equivalence class.
parent_ = parent_->FindRoot();
return parent_;
}
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_UNION_FIND_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册