提交 61bd2c65 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

[lite] When constructing a subgraph, use control dependencies from the model's...

[lite] When constructing a subgraph, use control dependencies from the model's metadata, if present.

PiperOrigin-RevId: 480984280
上级 164d5975
......@@ -415,6 +415,10 @@ if(TFLITE_ENABLE_EXTERNAL_DELEGATE)
FILTER ".*(_test|_tester)\\.(cc|h)"
)
endif()
populate_tflite_source_vars("experimental/remat"
TFLITE_EXPERIMENTAL_REMAT_SRCS
FILTER ".*_test\\.(cc|h)$"
)
if (TFLITE_ENABLE_RESOURCE)
populate_tflite_source_vars("experimental/resource"
TFLITE_EXPERIMENTAL_RESOURCE_SRCS
......@@ -484,6 +488,7 @@ set(_ALL_TFLITE_SRCS
${TFLITE_DELEGATES_SRCS}
${TFLITE_DELEGATES_XNNPACK_SRCS}
${TFLITE_DELEGATES_EXTERNAL_SRCS}
${TFLITE_EXPERIMENTAL_REMAT_SRCS}
${TFLITE_EXPERIMENTAL_RESOURCE_SRCS}
${TFLITE_EXPERIMENTAL_RUY_PROFILER_SRCS}
${TFLITE_EXPERIMENTAL_RUY_SRCS}
......
......@@ -33,6 +33,7 @@ cc_library(
"//tensorflow/lite/c:common",
"//tensorflow/lite/core/api",
"//tensorflow/lite/core/api:error_reporter",
"//tensorflow/lite/experimental/remat:metadata_util",
"//tensorflow/lite/experimental/resource",
"//tensorflow/lite/internal:signature_def",
"//tensorflow/lite/profiling:root_profiler",
......@@ -71,6 +72,7 @@ cc_library(
"//tensorflow/lite/c:common",
"//tensorflow/lite/core/api",
"//tensorflow/lite/core/api:error_reporter",
"//tensorflow/lite/experimental/remat:metadata_util",
"//tensorflow/lite/experimental/resource",
"//tensorflow/lite/internal:signature_def",
"//tensorflow/lite/profiling:root_profiler",
......@@ -115,6 +117,7 @@ cc_library(
"//tensorflow/lite/c:common_internal",
"//tensorflow/lite/core/api",
"//tensorflow/lite/experimental/resource",
"//tensorflow/lite/experimental/remat:metadata_util",
"//tensorflow/lite/profiling:root_profiler",
"//tensorflow/lite/schema:schema_fbs",
] + select({
......
......@@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/remat/metadata_util.h"
#include "tensorflow/lite/experimental/resource/initialization_status.h"
#include "tensorflow/lite/experimental/resource/resource_base.h"
#include "tensorflow/lite/external_cpu_backend_context.h"
......@@ -908,6 +909,14 @@ class Interpreter {
// InterpreterOptions object which is being used.
std::unique_ptr<InterpreterOptions> options_;
// Stores control edges that are encoded in the metadata of the model. Updated
// in SetMetadata; model_control_dependencies_.empty() means that there were
// no control dependencies encoded in the metadata, or that we were unable to
// parse them. We assume that, if we were able to parse them, they are
// consistent with the model and no further consistency check (e.g., bounds
// checks when dereferencing by subgraph and operator index) will take place.
ModelControlDependencies model_control_dependencies_;
};
} // namespace tflite
......
......@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/core/api/tensor_utils.h"
#include "tensorflow/lite/core/macros.h"
#include "tensorflow/lite/experimental/remat/metadata_util.h"
#include "tensorflow/lite/experimental/resource/resource_base.h"
#include "tensorflow/lite/graph_info.h"
#include "tensorflow/lite/memory_planner.h"
......@@ -400,6 +401,14 @@ void PopulatePreviewDelegateParams(const NodeSubset& node_subset,
} // namespace
TfLiteStatus Subgraph::PartitionGraph(const TfLiteIntArray* nodes_to_replace,
std::vector<NodeSubset>* node_subsets) {
const InterpreterInfo info(this);
return PartitionGraphIntoIndependentNodeSubsets(
&info, nodes_to_replace, node_subsets,
/*greedily=*/!DisableDelegateClustering(), control_edges_);
}
TfLiteStatus Subgraph::ReplaceNodeSubsetsWithDelegateKernels(
TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace,
TfLiteDelegate* delegate) {
......@@ -413,11 +422,10 @@ TfLiteStatus Subgraph::ReplaceNodeSubsetsWithDelegateKernels(
// Analyze the graph to find all independent node_subsets that are either
// fully not-this-delegate or this-delegate computation.
InterpreterInfo info(this);
std::vector<NodeSubset> node_subsets;
PartitionGraphIntoIndependentNodeSubsets(
&info, nodes_to_replace, &node_subsets,
/*greedily=*/!DisableDelegateClustering());
if (PartitionGraph(nodes_to_replace, &node_subsets) == kTfLiteError) {
return kTfLiteError;
}
// On Android the log message below is used for diagnosing delegation success
// also in production builds. Use VERBOSE here so that the logging is turned
......@@ -572,11 +580,10 @@ TfLiteStatus Subgraph::PreviewDelegatePartitioning(
}
// Partition the execution plan into node subsets.
InterpreterInfo info(this);
std::vector<NodeSubset> node_subsets;
PartitionGraphIntoIndependentNodeSubsets(
&info, nodes_to_replace, &node_subsets,
/*greedily=*/!DisableDelegateClustering());
if (PartitionGraph(nodes_to_replace, &node_subsets) == kTfLiteError) {
return kTfLiteError;
}
// Create one TfLiteDelegateParams per node-subset which would be delegated.
for (auto& node_subset : node_subsets) {
......@@ -623,8 +630,10 @@ TfLiteStatus Subgraph::SetVariables(std::vector<int> variables) {
}
TfLiteStatus Subgraph::SetMetadata(
const std::map<std::string, std::string>* metadata) {
const std::map<std::string, std::string>* metadata,
const ControlEdges* control_edges) {
metadata_ = metadata;
control_edges_ = control_edges;
return kTfLiteOk;
}
......
......@@ -591,6 +591,18 @@ class Subgraph {
TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace,
TfLiteDelegate* delegate);
// Helper method for PreviewDelegatePartitioning and
// ReplaceNodeSubsetsWithDelegateKernels. Creates node subsets whose members
// are either all present in or all absent from *nodes_to_replace. The
// NodeSubsets and their members are in schedulable order, where
// schedulability considers data dependencies and, if present, *control_edges_
// between nodes.
// If control_edges_ == nullptr, PartitionGraph will preserve the original
// execuion order of nodes with OpMightHaveSideEffect() when finding
// schedulable orderings.
TfLiteStatus PartitionGraph(const TfLiteIntArray* nodes_to_replace,
std::vector<NodeSubset>* node_subsets);
// WARNING: This is an experimental interface that is subject to change.
// Gets the internal pointer to a TensorFlow lite node by node_index.
TfLiteStatus GetNodeAndRegistration(int node_index, TfLiteNode** node,
......@@ -732,7 +744,8 @@ class Subgraph {
// Since the lifetime of the Interpreter exceeds the Subgraph, metadata
// remains valid for the latter's lifetime.
// Also sets relevant fields on context_ based on known metadata.
TfLiteStatus SetMetadata(const std::map<std::string, std::string>* metadata);
TfLiteStatus SetMetadata(const std::map<std::string, std::string>* metadata,
const ControlEdges* control_edges = nullptr);
// Initializes the mapping between tensor index to the index of the
// last operation that uses the tensor as input.
......@@ -925,6 +938,13 @@ class Subgraph {
// `InterpreterOptions` object which is being used and owned by Interpreter.
InterpreterOptions* options_;
// Control edges (i.e., dependencies between nodes in addition to their data
// dependencies); can be nullptr. Will be initialized from metadata associated
// with the owning interpreter; the pointee is owned by the owning
// interpreter. The owning interpreter will keep this consistent with
// metadata_ by appropriately parametrized SetMetadata method calls.
const ControlEdges* control_edges_ = nullptr;
};
} // namespace tflite
......
......@@ -133,6 +133,7 @@ cc_test(
"//tensorflow/lite/c:c_api_experimental",
"//tensorflow/lite/c:c_api_types",
"//tensorflow/lite/c:common",
"//tensorflow/lite/experimental/remat:metadata_util",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels/internal:compatibility",
......
......@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/delegate_test_util.h"
#include "tensorflow/lite/experimental/remat/metadata_util.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
......@@ -43,6 +44,7 @@ namespace delegates {
using test_utils::SimpleDelegate;
using test_utils::TestDelegate;
using test_utils::TestDelegateWithControlEdges;
using test_utils::TestFP16Delegation;
using test_utils::TestTwoDelegates;
......@@ -1354,6 +1356,68 @@ TEST_F(TestReleaseDynamicTensorWithDelegate, ShapePropagation_FlagNotSet) {
ASSERT_EQ(interpreter_->tensor(1)->data.raw, nullptr);
}
// Tests for control edges passed in metadata
// ==========================================
TEST_F(TestDelegateWithControlEdges, NoControlEdges) {
// Put {0,2} on a super-node, if possible
delegate_ = std::make_unique<SimpleDelegate>(std::vector<int>({0, 2}));
interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate());
ASSERT_EQ(interpreter_->execution_plan().size(), 3); // [ {0, 2}, 1, 3]
EXPECT_EQ(interpreter_->execution_plan().data()[0], 4); // new super-node
EXPECT_EQ(interpreter_->execution_plan().data()[1], 1); // undelegated
EXPECT_EQ(interpreter_->execution_plan().data()[2], 3); // undelegated
}
TEST_F(TestDelegateWithControlEdges, OverrideControlEdges) {
// Execute node 1 before node 2.
SetMetadata({{kModelControlDependenciesMetadataKey,
SerializeModelControlDependencies({{{1, 2}}})}});
// Put {0,2} on a super-node, if possible
delegate_ = std::make_unique<SimpleDelegate>(std::vector<int>({0, 2}));
interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate());
// 1 has to be executed before 2, so original execution order is
// preserved. Nodes 0 and 2 both get rewritten into new delegate nodes
// 4 and 5.
ASSERT_EQ(interpreter_->execution_plan().size(), 4); // [ 0, 1, 2, 3]
EXPECT_EQ(interpreter_->execution_plan().data()[0], 4);
EXPECT_EQ(interpreter_->execution_plan().data()[1], 1);
EXPECT_EQ(interpreter_->execution_plan().data()[2], 5);
EXPECT_EQ(interpreter_->execution_plan().data()[3], 3);
}
// Test that empty control edge metadata for subgraph 0 don't change anything.
TEST_F(TestDelegateWithControlEdges, EmptyControlEdges) {
SetMetadata({{kModelControlDependenciesMetadataKey,
SerializeModelControlDependencies({{}})}});
delegate_ = std::make_unique<SimpleDelegate>(std::vector<int>({0, 2}));
interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate());
EXPECT_EQ(interpreter_->execution_plan().size(), 3); // [ {0, 2}, 1, 3]
}
// Test that control edges that are compatible with execution order
// [0, 2, 1, 3] don't change anything (case 1).
TEST_F(TestDelegateWithControlEdges, CompatibleControlEdges1) {
// Execute node 0 before node 2 and node 1 before node 3.
SetMetadata({{kModelControlDependenciesMetadataKey,
SerializeModelControlDependencies({{{0, 2}, {1, 3}}})}});
delegate_ = std::make_unique<SimpleDelegate>(std::vector<int>({0, 2}));
interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate());
EXPECT_EQ(interpreter_->execution_plan().size(), 3); // [ {0, 2}, 1, 3]
}
// Test that control edges that are compatible with execution order
// [0, 2, 1, 3] don't change anything (case 2).
TEST_F(TestDelegateWithControlEdges, CompatibleControlEdges2) {
// Execute node 0 before node 1 and node 1 before node 3.
SetMetadata({{kModelControlDependenciesMetadataKey,
SerializeModelControlDependencies({{{0, 1}, {1, 3}}})}});
delegate_ = std::make_unique<SimpleDelegate>(std::vector<int>({0, 2}));
interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate());
EXPECT_EQ(interpreter_->execution_plan().size(), 3); // [ {0, 2}, 1, 3]
}
// Tests for FP16 graphs
// =====================
......
......@@ -560,6 +560,35 @@ TfLiteRegistration TestFP16Delegation::FP16Delegate::FakeFusedRegistration() {
return reg;
}
void TestDelegateWithControlEdges::SetUpSubgraph(Subgraph* subgraph) {
subgraph->AddTensors(5);
subgraph->SetInputs({0});
subgraph->SetOutputs({4});
std::vector<int> dims({3});
const TfLiteQuantization quant{kTfLiteNoQuantization, nullptr};
subgraph->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", dims.size(),
dims.data(), quant, false);
subgraph->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", dims.size(),
dims.data(), quant, false);
subgraph->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", dims.size(),
dims.data(), quant, false);
subgraph->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", dims.size(),
dims.data(), quant, false);
subgraph->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", dims.size(),
dims.data(), quant, false);
TfLiteRegistration reg = AddOpRegistration();
int node_index_ignored;
subgraph->AddNodeWithParameters({0, 0}, {1}, {}, nullptr, 0, nullptr, &reg,
&node_index_ignored);
subgraph->AddNodeWithParameters({1, 1}, {2}, {}, nullptr, 0, nullptr, &reg,
&node_index_ignored);
subgraph->AddNodeWithParameters({1, 1}, {3}, {}, nullptr, 0, nullptr, &reg,
&node_index_ignored);
subgraph->AddNodeWithParameters({2, 3}, {4}, {}, nullptr, 0, nullptr, &reg,
&node_index_ignored);
}
} // namespace test_utils
} // namespace delegates
} // namespace tflite
......@@ -17,7 +17,9 @@ limitations under the License.
#include <stdint.h>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
......@@ -91,6 +93,8 @@ class SimpleDelegate {
// Friend of Interpreter to access private methods.
class TestDelegation {
public:
virtual ~TestDelegation() {}
// Returns an empty interpreter that uses the same default delegates that are
// normally enabled by default.
static std::unique_ptr<Interpreter> NewInterpreterWithDefaultDelegates() {
......@@ -104,8 +108,11 @@ class TestDelegation {
TfLiteStatus RemoveAllDelegates() {
return interpreter_->RemoveAllDelegates();
}
void SetMetadata(const std::map<std::string, std::string>& metadata) {
interpreter_->SetMetadata(metadata);
}
void SetUpSubgraph(Subgraph* subgraph);
virtual void SetUpSubgraph(Subgraph* subgraph);
void AddSubgraphs(int subgraphs_to_add,
int* first_new_subgraph_index = nullptr);
......@@ -126,6 +133,30 @@ class TestDelegate : public TestDelegation, public ::testing::Test {
std::unique_ptr<SimpleDelegate> delegate_, delegate2_;
};
// Tests scenarios involving a single delegate and control edges.
// Subgraph 0 has the form
//
// /---OP2---\
// / \
// >---OP0 OP3--->
// \ /
// \---OP1---/
//
// Delegating OP0, OP2 will generate an execution graph with a "super-node"
// {OP0->OP2}, which can be disabled by adding (in metadata) a control edge
// between OP1 and OP2:
//
// /->-OP2---\
// / ^ \
// >---OP0 ^ OP3--->
// \ ^ /
// \---OP1---/
//
class TestDelegateWithControlEdges : public TestDelegate {
protected:
void SetUpSubgraph(Subgraph* subgraph) override;
};
// Tests scenarios involving two delegates, parametrized by the first & second
// delegate's flags.
class TestTwoDelegates
......
......@@ -9,6 +9,7 @@ package_group(
name = "friends",
packages = [
"//tensorflow/compiler/mlir/lite/...",
"//tensorflow/lite/...",
],
)
......
......@@ -414,9 +414,21 @@ TfLiteStatus Interpreter::RemoveAllDelegates() {
TfLiteStatus Interpreter::SetMetadata(
const std::map<std::string, std::string>& metadata) {
metadata_ = metadata;
const auto maybe_model_control_dependencies =
metadata_.find(kModelControlDependenciesMetadataKey);
if (maybe_model_control_dependencies == metadata_.end() ||
!ParseModelControlDependencies(
maybe_model_control_dependencies->second.data(),
maybe_model_control_dependencies->second.size(),
&model_control_dependencies_)) {
model_control_dependencies_.clear();
}
for (int subgraph_index = 0; subgraph_index < subgraphs_.size();
++subgraph_index) {
TF_LITE_ENSURE_STATUS(subgraphs_[subgraph_index]->SetMetadata(&metadata_));
TF_LITE_ENSURE_STATUS(subgraphs_[subgraph_index]->SetMetadata(
&metadata_, model_control_dependencies_.empty()
? nullptr
: &model_control_dependencies_[subgraph_index]));
}
return kTfLiteOk;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册