提交 f9fbff63 编写于 作者: R Rachel Lim 提交者: TensorFlower Gardener

[tf.data] Add option to control whether vectorization is aggressive (i.e....

[tf.data] Add option to control whether vectorization is aggressive (i.e. always vectorizes) or safe (i.e. uses ChooseFastestBranchDataset)

PiperOrigin-RevId: 239203789
上级 f6dfeeec
......@@ -548,6 +548,7 @@ cc_library(
hdrs = ["meta_optimizer.h"],
deps = [
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/optimizers:arithmetic_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
......
......@@ -534,44 +534,66 @@ Status MapVectorization::OptimizeAndCollectStats(Cluster* cluster,
AddVectorizedFunction(*map_node, *map_func, library);
CHECK_NOTNULL(vectorized_func);
std::vector<const NodeDef*> vectorized_branch;
NodeDef* new_batch_node;
TF_RETURN_IF_ERROR(AddNewBatchNode(
*batch_node, *input_node, *vectorized_func, &graph, &new_batch_node));
vectorized_branch.push_back(new_batch_node);
NodeDef* new_map_node;
TF_RETURN_IF_ERROR(AddNewMapNode(*map_node, *batch_node, *new_batch_node,
*vectorized_func, &graph, &new_map_node));
vectorized_branch.push_back(new_map_node);
NodeDef* optional_new_prefetch_node = nullptr;
if (optional_prefetch_node) {
// If the original pipeline was .map().prefetch().batch(), the new
// pipeline is .batch().map().prefetch()
NodeDef* new_prefetch_node;
TF_RETURN_IF_ERROR(AddNewPrefetchNode(*optional_prefetch_node,
*batch_node, *new_map_node, &graph,
&new_prefetch_node));
vectorized_branch.push_back(new_prefetch_node);
&optional_new_prefetch_node));
}
std::vector<const NodeDef*> vectorized_branch(
{new_batch_node, new_map_node});
std::vector<const NodeDef*> original_branch({map_node});
if (optional_prefetch_node) {
original_branch.push_back(optional_prefetch_node);
vectorized_branch.push_back(optional_new_prefetch_node);
}
if (map_node->op() != kExperimentalMapAndBatchOp) {
if (batch_node->op() != kExperimentalMapAndBatchOp) {
original_branch.push_back(batch_node);
}
NodeDef* new_choose_fastest_node;
TF_RETURN_IF_ERROR(AddNewChooseFastestNode(
input_node, /*ratio_numerator_name=*/new_batch_node->input(1),
std::move(original_branch), std::move(vectorized_branch), &graph,
library, &new_choose_fastest_node));
// Mark the original nodes for deletion.
for (const auto& n : original_branch) {
nodes_to_delete.insert(n->name());
}
if (use_choose_fastest_) {
// Optionally, use ChooseFastestBranch node to mitigate potential
// regressions caused by vectorization.
for (const auto& n : vectorized_branch) {
// Mark the vectorized nodes for deletion, since they will be added in
// the choose fastest dataset branch function separately.
nodes_to_delete.insert(n->name());
}
NodeDef* new_choose_fastest_node;
TF_RETURN_IF_ERROR(AddNewChooseFastestNode(
input_node, /*ratio_numerator_name=*/new_batch_node->input(1),
std::move(original_branch), std::move(vectorized_branch), &graph,
library, &new_choose_fastest_node));
// Make output of Batch point to ChooseFastest instead.
TF_RETURN_IF_ERROR(graph.UpdateFanouts(batch_node->name(),
new_choose_fastest_node->name()));
} else {
// Make output of Batch point to the new Map (or Prefetch) node instead.
TF_RETURN_IF_ERROR(graph.UpdateFanouts(
batch_node->name(), optional_new_prefetch_node
? optional_new_prefetch_node->name()
: new_map_node->name()));
}
// Make output of Batch point to ChooseFastest instead.
TF_RETURN_IF_ERROR(graph.UpdateFanouts(batch_node->name(),
new_choose_fastest_node->name()));
TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete));
stats->num_changes++;
}
return Status::OK();
......
......@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"
namespace tensorflow {
......@@ -33,10 +34,11 @@ namespace grappler {
// (or map_and_batch)
//
// To:
// input --> map --> batch --------+
// | (or map_and_batch) |
// | v
// +-----> batch --> map --> choose_fastest --> output
// input --> batch --> map --> output
//
// If the "ChooseFastest" configuration is enabled, it adds a
// ChooseFastestBranch dataset node to pick between the original map->batch
// branch and the vectorized batch->map branch.
//
class MapVectorization : public TFDataOptimizerBase {
public:
......@@ -47,6 +49,19 @@ class MapVectorization : public TFDataOptimizerBase {
// Reads the optional "use_choose_fastest" parameter ("true"/"false") from the
// optimizer config. Missing config or missing key leaves the default value of
// `use_choose_fastest_` untouched; any other string is an error.
Status Init(
    const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
  if (!config) return Status::OK();

  // `at()` would throw std::out_of_range when the parameter is absent (e.g.
  // an optimizer config that sets unrelated keys), so look it up safely.
  const auto& params = config->parameter_map();
  const auto param_it = params.find("use_choose_fastest");
  if (param_it == params.end()) return Status::OK();

  const string& choose_fastest_param = param_it->second.s();
  if (choose_fastest_param == "true") {
    use_choose_fastest_ = true;
  } else if (choose_fastest_param == "false") {
    use_choose_fastest_ = false;
  } else {
    return errors::Internal(
        "Received an invalid value for parameter \"use_choose_fastest\": ",
        choose_fastest_param);
  }
  return Status::OK();
}
......@@ -56,6 +71,9 @@ class MapVectorization : public TFDataOptimizerBase {
void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimize_output, double result) override;
private:
bool use_choose_fastest_ = false;
};
} // namespace grappler
......
......@@ -53,6 +53,19 @@ constexpr char kAttrNameDtype[] = "dtype";
using test::function::NDef;
// Runs the MapVectorization optimizer over `item` with the
// "use_choose_fastest" parameter set from `use_choose_fastest`, writing the
// rewritten graph to `output`.
Status OptimizeWithMapVectorization(const GrapplerItem& item, GraphDef* output,
                                    bool use_choose_fastest) {
  MapVectorization optimizer;
  RewriterConfig_CustomGraphOptimizer config;
  // The optimizer consumes the flag as a string parameter; a conditional
  // expression avoids duplicating the set_s call in two branches.
  (*config.mutable_parameter_map())["use_choose_fastest"].set_s(
      use_choose_fastest ? "true" : "false");
  TF_RETURN_IF_ERROR(optimizer.Init(&config));
  return optimizer.Optimize(nullptr, item, output);
}
// Adds a simple vectorizable map function that is akin to
// dataset.map(lambda x: tf.identity(x))
FunctionDef* AddMapFn(MutableGraphView* graph) {
......@@ -188,6 +201,35 @@ const FunctionDef* GetFunction(const GraphDef& graph,
return &graph.library().function(found);
}
// Verifies that `output` contains the vectorized batch->map(->prefetch) chain
// given by `expected_vectorized_branch` (op names, in pipeline order), wired
// directly into `input_name` with no ChooseFastestBranch wrapper, and that the
// map function was vectorized to a plain Identity.
void CheckVectorizedWithoutChooseFastest(
    const GraphDef& output, gtl::ArraySlice<string> expected_vectorized_branch,
    const string& input_name) {
  std::vector<const NodeDef*> vectorized_branch;
  for (const auto& op : expected_vectorized_branch) {
    // This assumes that vectorized op is the only one that exists in the graph.
    // For our test cases, this is true (we don't have superfluous map/batch
    // nodes in other parts of the pipeline).
    ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(op, output).size(), 1);
    vectorized_branch.push_back(
        &output.node(graph_utils::FindGraphNodeWithOp(op, output)));
  }

  // Check every edge of the chain, including branch[0] -> branch[1]; starting
  // the loop at index 1 would leave that first edge unverified. The index is
  // arranged as `i + 1 < size` to avoid unsigned underflow on an empty branch.
  for (size_t i = 0; i + 1 < vectorized_branch.size(); ++i) {
    const NodeDef* node = vectorized_branch[i];
    const NodeDef* next_node = vectorized_branch[i + 1];
    ASSERT_EQ(next_node->input(0), node->name());
  }
  ASSERT_EQ(vectorized_branch[0]->input(0), input_name);

  // The map node is always the second entry ({batch, map, [prefetch]}).
  const NodeDef* vectorized_map_node = vectorized_branch[1];
  string function_name =
      vectorized_map_node->attr().at(kAttrNameF).func().name();
  const FunctionDef* function = GetFunction(output, function_name);
  ASSERT_NE(function, nullptr);
  EXPECT_EQ(function->node_def(0).op(), "Identity");
}
// Checks that a graph has undergone the map_vectorization transformation
// successfully, whereby the new graph has the shape:
//
......@@ -198,10 +240,15 @@ const FunctionDef* GetFunction(const GraphDef& graph,
// |
// +--> old map --> old batch
//
void CheckVectorized(const GraphDef& output,
gtl::ArraySlice<string> expected_vectorized_branch,
gtl::ArraySlice<string> expected_original_branch,
const string& input_name) {
void CheckVectorizedWithChooseFastest(
const GraphDef& output, gtl::ArraySlice<string> expected_vectorized_branch,
gtl::ArraySlice<string> expected_original_branch,
const string& input_name) {
for (const auto& op : {kBatchOp, kBatchV2Op, kMapOp, kParallelMapOp,
kExperimentalMapAndBatchOp}) {
// Check that the dataset nodes have been removed from the main graph.
ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(op, output).size(), 0);
}
ASSERT_EQ(
graph_utils::FindAllGraphNodesWithOp(kChooseFastestOp, output).size(), 1);
const NodeDef& choose_fastest_node =
......@@ -234,12 +281,13 @@ void CheckVectorized(const GraphDef& output,
}
class MapThenBatchTest
: public ::testing::TestWithParam<std::tuple<int, bool, int>> {};
: public ::testing::TestWithParam<std::tuple<int, bool, int, bool>> {};
TEST_P(MapThenBatchTest, IsVectorized) {
int num_parallel_calls = std::get<0>(GetParam());
bool use_batch_v2 = std::get<1>(GetParam());
int prefetch = std::get<2>(GetParam());
bool use_choose_fastest = std::get<3>(GetParam());
GrapplerItem item;
MutableGraphView graph(&item.graph);
auto range_dataset = AddRangeNode(&graph);
......@@ -251,9 +299,8 @@ TEST_P(MapThenBatchTest, IsVectorized) {
dataset = AddPrefetchNode(&graph, dataset->name(), prefetch);
}
dataset = AddBatchNode(&graph, dataset->name(), use_batch_v2);
MapVectorization optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, use_choose_fastest));
std::vector<string> expected_original_branch;
expected_original_branch.push_back(num_parallel_calls > 0 ? kParallelMapOp
......@@ -272,14 +319,24 @@ TEST_P(MapThenBatchTest, IsVectorized) {
expected_vectorized_branch.push_back(kPrefetchOp);
}
CheckVectorized(output, expected_vectorized_branch, expected_original_branch,
range_dataset->name());
if (use_choose_fastest) {
CheckVectorizedWithChooseFastest(output, expected_vectorized_branch,
expected_original_branch,
range_dataset->name());
} else {
CheckVectorizedWithoutChooseFastest(output, expected_vectorized_branch,
range_dataset->name());
}
}
INSTANTIATE_TEST_SUITE_P(MapThenBatchTest, MapThenBatchTest,
::testing::Combine(::testing::Values(0, 12),
::testing::Bool(),
::testing::Values(0, 20)));
::testing::Values(0, 20),
::testing::Bool()));
// Parameterized on `use_choose_fastest` — whether the vectorized pipeline is
// wrapped in a ChooseFastestBranch dataset (see VectorizeExperimentalMapAndBatch).
class MapAndBatchTest : public ::testing::TestWithParam<bool> {};
NodeDef* AddMapAndBatchNode(MutableGraphView* graph,
const string& input_dataset, const string& map_fn,
......@@ -307,7 +364,7 @@ NodeDef* AddMapAndBatchNode(MutableGraphView* graph,
return graph->AddNode(std::move(result));
}
TEST(MapVectorizationTest, VectorizeExperimentalMapAndBatch) {
TEST_P(MapAndBatchTest, VectorizeExperimentalMapAndBatch) {
GrapplerItem item;
MutableGraphView graph(&item.graph);
auto range_node = AddRangeNode(&graph);
......@@ -316,16 +373,24 @@ TEST(MapVectorizationTest, VectorizeExperimentalMapAndBatch) {
map_fn->signature().name());
ASSERT_NE(map_and_batch_node, nullptr);
MapVectorization optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
bool use_choose_fastest = GetParam();
CheckVectorized(output, {kBatchV2Op, kParallelMapOp},
{kExperimentalMapAndBatchOp}, range_node->name());
TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, use_choose_fastest));
if (use_choose_fastest) {
CheckVectorizedWithChooseFastest(output, {kBatchV2Op, kParallelMapOp},
{kExperimentalMapAndBatchOp},
range_node->name());
} else {
CheckVectorizedWithoutChooseFastest(output, {kBatchV2Op, kParallelMapOp},
range_node->name());
}
}
INSTANTIATE_TEST_SUITE_P(MapAndBatchTest, MapAndBatchTest, ::testing::Bool());
class ChainedMapAndBatchTest
: public ::testing::TestWithParam<std::tuple<bool, bool>> {};
: public ::testing::TestWithParam<std::tuple<bool, bool, bool>> {};
// Tests:
// 1) map.batch.map.batch
......@@ -352,52 +417,76 @@ TEST_P(ChainedMapAndBatchTest, IsVectorized) {
bool fuse_0 = std::get<0>(GetParam());
bool fuse_1 = std::get<1>(GetParam());
bool use_choose_fastest = std::get<2>(GetParam());
auto map_and_batch_0 = make_map_and_batch(input_node, fuse_0);
auto map_and_batch_1 = make_map_and_batch(map_and_batch_0, fuse_1);
ASSERT_NE(map_and_batch_1, nullptr);
MapVectorization optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, use_choose_fastest));
TF_ASSERT_OK(TopologicalSort(&output));
std::vector<int> choose_fastest_nodes =
graph_utils::FindAllGraphNodesWithOp(kChooseFastestOp, output);
ASSERT_EQ(choose_fastest_nodes.size(), 2);
std::vector<string> fused_sequence({kExperimentalMapAndBatchOp});
std::vector<string> unfused_sequence({kParallelMapOp, kBatchV2Op});
const NodeDef& range_node =
output.node(graph_utils::FindGraphNodeWithOp(kRangeOp, output));
const NodeDef& choose_fastest_0 = output.node(choose_fastest_nodes[0]);
ASSERT_EQ(choose_fastest_0.input(0), range_node.name());
const NodeDef& choose_fastest_1 = output.node(choose_fastest_nodes[1]);
ASSERT_EQ(choose_fastest_1.input(0), choose_fastest_0.name());
auto check_branches = [&output](const NodeDef& choose_fastest_node,
gtl::ArraySlice<string> original_ops) {
const auto& functions_list =
choose_fastest_node.attr().at("branches").list();
// Branch 0: vectorized
const FunctionDef* branch_0 =
GetFunction(output, functions_list.func(0).name());
ASSERT_NE(branch_0, nullptr);
CheckBranch(*branch_0, {kBatchV2Op, kParallelMapOp});
// Branch 1: original
const FunctionDef* branch_1 =
GetFunction(output, functions_list.func(1).name());
ASSERT_NE(branch_1, nullptr);
CheckBranch(*branch_1, original_ops);
};
check_branches(choose_fastest_0, fuse_0 ? fused_sequence : unfused_sequence);
check_branches(choose_fastest_1, fuse_1 ? fused_sequence : unfused_sequence);
if (use_choose_fastest) {
std::vector<int> choose_fastest_nodes =
graph_utils::FindAllGraphNodesWithOp(kChooseFastestOp, output);
ASSERT_EQ(choose_fastest_nodes.size(), 2);
std::vector<string> fused_sequence({kExperimentalMapAndBatchOp});
std::vector<string> unfused_sequence({kParallelMapOp, kBatchV2Op});
const NodeDef& range_node =
output.node(graph_utils::FindGraphNodeWithOp(kRangeOp, output));
const NodeDef& choose_fastest_0 = output.node(choose_fastest_nodes[0]);
ASSERT_EQ(choose_fastest_0.input(0), range_node.name());
const NodeDef& choose_fastest_1 = output.node(choose_fastest_nodes[1]);
ASSERT_EQ(choose_fastest_1.input(0), choose_fastest_0.name());
auto check_branches = [&output](const NodeDef& choose_fastest_node,
gtl::ArraySlice<string> original_ops) {
const auto& functions_list =
choose_fastest_node.attr().at("branches").list();
// Branch 0: vectorized
const FunctionDef* branch_0 =
GetFunction(output, functions_list.func(0).name());
ASSERT_NE(branch_0, nullptr);
CheckBranch(*branch_0, {kBatchV2Op, kParallelMapOp});
// Branch 1: original
const FunctionDef* branch_1 =
GetFunction(output, functions_list.func(1).name());
ASSERT_NE(branch_1, nullptr);
CheckBranch(*branch_1, original_ops);
};
check_branches(choose_fastest_0,
fuse_0 ? fused_sequence : unfused_sequence);
check_branches(choose_fastest_1,
fuse_1 ? fused_sequence : unfused_sequence);
} else {
std::vector<int> map_nodes =
graph_utils::FindAllGraphNodesWithOp(kParallelMapOp, output);
std::vector<int> batch_nodes =
graph_utils::FindAllGraphNodesWithOp(kBatchV2Op, output);
ASSERT_EQ(map_nodes.size(), 2);
ASSERT_EQ(batch_nodes.size(), 2);
const NodeDef& range_node =
output.node(graph_utils::FindGraphNodeWithOp(kRangeOp, output));
const NodeDef& batch_node_0 = output.node(batch_nodes[0]);
EXPECT_EQ(batch_node_0.input(0), range_node.name());
const NodeDef& map_node_0 = output.node(map_nodes[0]);
EXPECT_EQ(map_node_0.input(0), batch_node_0.name());
const NodeDef& batch_node_1 = output.node(batch_nodes[1]);
EXPECT_EQ(batch_node_1.input(0), map_node_0.name());
const NodeDef& map_node_1 = output.node(map_nodes[1]);
EXPECT_EQ(map_node_1.input(0), batch_node_1.name());
}
}
INSTANTIATE_TEST_SUITE_P(ChainedMapAndBatchTest, ChainedMapAndBatchTest,
::testing::Combine(::testing::Bool(),
::testing::Bool(),
::testing::Bool()));
// Not all dataset types have "output_shapes" and "output_types"
......@@ -434,9 +523,8 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputShapes) {
auto map_node =
AddMapNode(&graph, input_node->name(), map_fn->signature().name());
auto batch_node = AddBatchNode(&graph, map_node->name());
MapVectorization optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, true));
CheckNotVectorized(output, map_node->op(), batch_node->op(),
input_node->name());
}
......@@ -454,9 +542,8 @@ TEST(MapVectorizationTest, VectorizeWithUnknownRank) {
auto map_node =
AddMapNode(&graph, input_node->name(), map_fn->signature().name());
auto batch_node = AddBatchNode(&graph, map_node->name());
MapVectorization optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, true));
CheckNotVectorized(output, map_node->op(), batch_node->op(),
input_node->name());
}
......@@ -474,9 +561,8 @@ TEST(MapVectorizationTest, VectorizeWithUnknownDim) {
auto map_node =
AddMapNode(&graph, input_node->name(), map_fn->signature().name());
auto batch_node = AddBatchNode(&graph, map_node->name());
MapVectorization optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, true));
CheckNotVectorized(output, map_node->op(), batch_node->op(),
input_node->name());
}
......@@ -493,10 +579,9 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
auto map_node =
AddMapNode(&graph, input_node->name(), map_fn->signature().name());
auto batch_node = AddBatchNode(&graph, map_node->name());
MapVectorization optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
CheckVectorized(
TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, true));
CheckVectorizedWithChooseFastest(
output, /*expected_vectorized_branch=*/{batch_node->op(), map_node->op()},
/*expected_original_branch=*/{map_node->op(), batch_node->op()},
input_node->name());
......
......@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/meta_optimizer.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
......@@ -29,6 +30,50 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
namespace {
using ConfigMap =
std::map<string, tensorflow::RewriterConfig_CustomGraphOptimizer>;
// Parses a list of string optimizer configurations into a map from
// optimizer name -> rewriter config for that optimizer.
// Parses the "optimizer_configs" string list of `config` into `result`, a map
// from optimizer name to that optimizer's rewriter config. Each entry must
// have the form <optimizer_name>:<config_key>:<config_value>; a malformed
// entry yields an Internal error. Absence of "optimizer_configs" is not an
// error (the map is left unchanged).
Status ToConfigMap(
    const tensorflow::RewriterConfig_CustomGraphOptimizer* config,
    ConfigMap* result) {
  auto found = gtl::FindOrNull(config->parameter_map(), "optimizer_configs");
  if (!found) return Status::OK();

  auto& options = found->list().s();
  for (const auto& option_string : options) {
    // The option string has the format
    // <optimizer_name>:<config_key>:<config_value>
    std::vector<string> split = absl::StrSplit(option_string, ':');
    if (split.size() != 3) {
      return errors::Internal(
          "Wrong format for optimizer options. Expect <optimizer name>:<config "
          "key>:<config value>, received: ",
          option_string);
    }
    const string& optimizer_name = split[0];
    const string& config_key = split[1];
    const string& config_value = split[2];

    // operator[] default-constructs the optimizer's config on first use,
    // replacing the FindOrNull / insert / FindOrNull double lookup.
    tensorflow::RewriterConfig_CustomGraphOptimizer& optimizer_config =
        (*result)[optimizer_name];
    (*optimizer_config.mutable_parameter_map())[config_key].set_s(config_value);
  }

  return Status::OK();
}
} // namespace
Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
// Stores the optimized item so far.
......@@ -86,13 +131,16 @@ Status TFDataMetaOptimizer::Init(
// Initialize custom tf.data optimizers based on config.
auto& optimizers = config->parameter_map().at("optimizers").list().s();
ConfigMap optimizer_configs;
TF_RETURN_IF_ERROR(ToConfigMap(config, &optimizer_configs));
for (const auto& optimizer_name : optimizers) {
auto optimizer =
CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);
if (optimizer) {
// None of our data optimizers implement a meaningful Init function.
// This returns an error in case any of them does.
TF_RETURN_IF_ERROR(optimizer->Init());
TF_RETURN_IF_ERROR(
optimizer->Init(gtl::FindOrNull(optimizer_configs, optimizer_name)));
enabled_optimizers_[optimizer_name] = std::move(optimizer);
} else {
// This should never happen.
......
......@@ -36,6 +36,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
graph_def_version_(ctx->graph_def_version()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
OP_REQUIRES_OK(ctx,
ctx->GetAttr("optimization_configs", &optimizer_configs_));
}
protected:
......@@ -44,8 +46,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
std::vector<string> optimizations;
OP_REQUIRES_OK(
ctx, ParseVectorArgument<string>(ctx, "optimizations", &optimizations));
Dataset* dataset =
new Dataset(ctx, input, optimizations, output_types_, output_shapes_);
Dataset* dataset = new Dataset(ctx, input, optimizations, output_types_,
output_shapes_, optimizer_configs_);
Status s = dataset->Optimize(ctx);
if (s.ok()) {
*output = dataset;
......@@ -61,9 +63,11 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const std::vector<string>& optimizations,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
const std::vector<PartialTensorShape>& output_shapes,
const std::vector<string>& optimizer_configs)
: GraphRewriteDataset(ctx, input, output_types, output_shapes),
optimizations_(optimizations) {}
optimizations_(optimizations),
optimizer_configs_(optimizer_configs) {}
string DebugString() const override { return "OptimizeDatasetOp::Dataset"; }
......@@ -81,15 +85,23 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
for (const auto& opt : optimizations_) {
custom_optimizations_list->add_s(opt);
}
auto* config_list =
(*custom_optimizer->mutable_parameter_map())["optimizer_configs"]
.mutable_list();
for (const auto& config : optimizer_configs_) {
config_list->add_s(config);
}
return rewriter_config;
}
const std::vector<string> optimizations_;
const std::vector<string> optimizer_configs_;
};
const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
std::vector<string> optimizer_configs_;
};
REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU),
......
......@@ -624,6 +624,7 @@ REGISTER_OP("OptimizeDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("optimization_configs: list(string) = []")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("OptionalFromValue")
......
......@@ -26,6 +26,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@CheckpointInputPipelineHook
@@CsvDataset
@@DatasetStructure
@@MapVectorizationOptions
@@NestedStructure
@@OptimizationOptions
@@Optional
......@@ -102,6 +103,7 @@ from tensorflow.python.data.experimental.ops.interleave_ops import sample_from_d
from tensorflow.python.data.experimental.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.python.data.experimental.ops.iterator_ops import make_saveable_from_iterator
from tensorflow.python.data.experimental.ops.optimization import AUTOTUNE
from tensorflow.python.data.experimental.ops.optimization_options import MapVectorizationOptions
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.experimental.ops.parsing_ops import parse_example_dataset
from tensorflow.python.data.experimental.ops.prefetching_ops import copy_to_device
......
......@@ -321,6 +321,13 @@ def _generate_optimization_test_cases():
@test_util.run_all_in_graph_and_eager_modes
class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
def _enable_map_vectorization(self, dataset, use_choose=True):
  """Returns `dataset` with the map_vectorization optimization enabled.

  Args:
    dataset: the `tf.data.Dataset` to configure.
    use_choose: whether the rewrite should wrap the vectorized pipeline in a
      ChooseFastestBranch dataset.
  """
  opts = dataset_ops.Options()
  map_vec = opts.experimental_optimization.map_vectorization
  map_vec.enabled = True
  map_vec.use_choose_fastest = use_choose
  return dataset.with_options(opts)
def _get_test_datasets(self,
base_dataset,
map_fn,
......@@ -361,10 +368,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
# to verify the optimization result.
optimized = _make_dataset(["ChooseFastestBranch"]
if expect_optimized else [map_node_name, "Batch"])
options = dataset_ops.Options()
options.experimental_optimization.apply_default_optimizations = False
options.experimental_optimization.map_vectorization = True
optimized = optimized.with_options(options)
optimized = self._enable_map_vectorization(optimized)
return unoptimized, optimized
@parameterized.named_parameters(_generate_optimization_test_cases())
......@@ -404,16 +408,12 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
def testOptimizationWithMapAndBatchFusion(self):
# Tests that vectorization works on fused map and batch.
y = constant_op.constant(1, shape=(2,))
z = constant_op.constant(2, shape=(2,))
def map_fn(x):
return x, y, z
return x**2
base_dataset = dataset_ops.Dataset.range(1000)
options = dataset_ops.Options()
options.experimental_optimization.apply_default_optimizations = False
base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
[3, 4]]).repeat(5)
base_dataset = base_dataset.with_options(options)
def _make_dataset(node_names):
......@@ -423,9 +423,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
unoptimized = _make_dataset(["MapAndBatch"])
optimized = _make_dataset(["ChooseFastestBranch"])
options = dataset_ops.Options()
options.experimental_optimization.map_vectorization = True
optimized = optimized.with_options(options)
optimized = self._enable_map_vectorization(optimized)
self.assertDatasetsEqual(optimized, unoptimized)
@parameterized.named_parameters(
......@@ -474,10 +472,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
unoptimized = make_dataset(unoptimized_seq)
optimized = make_dataset(["ChooseFastestBranch", "ChooseFastestBranch"])
options = dataset_ops.Options()
options.experimental_optimization.map_vectorization = True
optimized = optimized.with_options(options)
optimized = self._enable_map_vectorization(optimized)
self.assertDatasetsEqual(optimized, unoptimized)
def testOptimizationIgnoreStateful(self):
......@@ -536,9 +531,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
options.experimental_optimization.apply_default_optimizations = False
unoptimized = unoptimized.with_options(options)
options = dataset_ops.Options()
options.experimental_optimization.map_vectorization = True
optimized = unoptimized.with_options(options)
optimized = self._enable_map_vectorization(unoptimized)
self.assertDatasetsEqual(unoptimized, optimized)
def testOptimizationWithSparseTensor(self):
......@@ -554,10 +547,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
options = dataset_ops.Options()
options.experimental_optimization.apply_default_optimizations = False
unoptimized = unoptimized.with_options(options)
options = dataset_ops.Options()
options.experimental_optimization.map_vectorization = True
optimized = unoptimized.with_options(options)
optimized = self._enable_map_vectorization(unoptimized)
self.assertDatasetsEqual(unoptimized, optimized)
def testOptimizationWithPrefetch(self):
......@@ -565,11 +555,16 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset.map(lambda x: x)
dataset = dataset.prefetch(1)
dataset = dataset.batch(10)
options = dataset_ops.Options()
options.experimental_optimization.map_vectorization = True
dataset = dataset.with_options(options)
dataset = self._enable_map_vectorization(dataset)
self.assertDatasetProduces(dataset, [list(range(10))])
def testOptimizationWithoutChooseFastest(self):
  # With use_choose=False the vectorized batch->map pipeline replaces the
  # original map->batch directly; output elements must be unchanged.
  dataset = dataset_ops.Dataset.range(10).map(lambda x: x**2).batch(10)
  dataset = self._enable_map_vectorization(dataset, use_choose=False)
  self.assertDatasetProduces(dataset, [[x**2 for x in range(10)]])
if __name__ == "__main__":
test.main()
......@@ -17,11 +17,42 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.util import options
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.MapVectorizationOptions")
class MapVectorizationOptions(options.OptionsBase):
  """Represents options for the MapVectorization optimization."""
  # TODO(rachelim): Other configuration parameters can go here, for example,
  # how many "experiments" to run with ChooseFastestBranchDataset.
  enabled = options.create_option(
      name="enabled",
      ty=bool,
      docstring=
      "Whether to vectorize map transformations. If None, defaults to False."
  )

  use_choose_fastest = options.create_option(
      name="use_choose_fastest",
      ty=bool,
      docstring="Whether to use ChooseFastestBranchDataset with this "
      "transformation. If True, the pipeline picks between the vectorized and "
      "original segment at runtime based on their iterations speed. If None, "
      "defaults to False.")

  def _static_optimizations(self):
    # Name of the graph rewrite to run when this optimization is enabled.
    return ["map_vectorization"] if self.enabled else []

  def _static_optimization_configs(self):
    # Config strings have the form <optimizer_name>:<config_key>:<config_value>.
    flag = "true" if self.use_choose_fastest else "false"
    return ["map_vectorization:use_choose_fastest:" + flag]
@tf_export("data.experimental.OptimizationOptions")
class OptimizationOptions(options.OptionsBase):
"""Represents options for dataset optimizations.
......@@ -102,9 +133,11 @@ class OptimizationOptions(options.OptionsBase):
map_vectorization = options.create_option(
name="map_vectorization",
ty=bool,
ty=MapVectorizationOptions,
docstring=
"Whether to vectorize map transformations. If None, defaults to False.")
"The map vectorization options associated with the dataset. See "
"`tf.data.experimental.MapVectorizationOptions` for more details.",
default_factory=MapVectorizationOptions)
noop_elimination = options.create_option(
name="noop_elimination",
......@@ -128,7 +161,6 @@ class OptimizationOptions(options.OptionsBase):
"map_and_filter_fusion",
"map_parallelization",
"map_fusion",
"map_vectorization",
"noop_elimination",
"shuffle_and_repeat_fusion",
]
......@@ -147,4 +179,12 @@ class OptimizationOptions(options.OptionsBase):
for optimization in optimizations_to_disable:
if getattr(self, optimization) is not False:
result.add(optimization)
if self.map_vectorization is not None:
result.update(self.map_vectorization._static_optimizations()) # pylint: disable=protected-access
return sorted(list(result))
def _static_optimization_configs(self):
if self.map_vectorization is not None:
return self.map_vectorization._static_optimization_configs() # pylint: disable=protected-access
return []
......@@ -191,7 +191,8 @@ class DatasetV2(object):
"`tf.enable_resource_variables()` at the start of the program." %
", ".join(static_optimizations))
else:
dataset = _OptimizeDataset(dataset, static_optimizations)
dataset = _OptimizeDataset(dataset, static_optimizations,
options._static_optimization_configs()) # pylint: disable=protected-access
autotune = True
cpu_budget = 0 # Indicates that all CPU cores should be used.
......@@ -2009,6 +2010,10 @@ class Options(options_lib.OptionsBase):
result.append("latency_all_edges")
return result
def _static_optimization_configs(self):
"""Produces the list of configurations for enabled static optimizations."""
return self.experimental_optimization._static_optimization_configs() # pylint: disable=protected-access
def merge(self, options):
"""Merges itself with the given `tf.data.Options`.
......@@ -3295,15 +3300,18 @@ class _ModelDataset(UnaryUnchangedStructureDataset):
class _OptimizeDataset(UnaryUnchangedStructureDataset):
"""A `Dataset` that acts as an identity, and applies optimizations."""
def __init__(self, input_dataset, optimizations):
def __init__(self, input_dataset, optimizations, optimization_configs=None):
self._input_dataset = input_dataset
if optimizations is None:
optimizations = []
if optimization_configs is None:
optimization_configs = []
self._optimizations = ops.convert_to_tensor(
optimizations, dtype=dtypes.string, name="optimizations")
variant_tensor = gen_dataset_ops.optimize_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
self._optimizations,
optimization_configs=optimization_configs,
**flat_structure(self))
super(_OptimizeDataset, self).__init__(input_dataset, variant_tensor)
......
path: "tensorflow.data.experimental.MapVectorizationOptions"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.optimization_options.MapVectorizationOptions\'>"
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
is_instance: "<type \'object\'>"
member {
name: "enabled"
mtype: "<type \'property\'>"
}
member {
name: "use_choose_fastest"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
......@@ -20,6 +20,10 @@ tf_module {
name: "INFINITE_CARDINALITY"
mtype: "<type \'int\'>"
}
member {
name: "MapVectorizationOptions"
mtype: "<type \'type\'>"
}
member {
name: "NestedStructure"
mtype: "<type \'type\'>"
......
......@@ -2182,7 +2182,7 @@ tf_module {
}
member_method {
name: "OptimizeDataset"
argspec: "args=[\'input_dataset\', \'optimizations\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input_dataset\', \'optimizations\', \'output_types\', \'output_shapes\', \'optimization_configs\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'None\'], "
}
member_method {
name: "OptionalFromValue"
......
path: "tensorflow.data.experimental.MapVectorizationOptions"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.optimization_options.MapVectorizationOptions\'>"
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
is_instance: "<type \'object\'>"
member {
name: "enabled"
mtype: "<type \'property\'>"
}
member {
name: "use_choose_fastest"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
......@@ -20,6 +20,10 @@ tf_module {
name: "INFINITE_CARDINALITY"
mtype: "<type \'int\'>"
}
member {
name: "MapVectorizationOptions"
mtype: "<type \'type\'>"
}
member {
name: "NestedStructure"
mtype: "<type \'type\'>"
......
......@@ -2182,7 +2182,7 @@ tf_module {
}
member_method {
name: "OptimizeDataset"
argspec: "args=[\'input_dataset\', \'optimizations\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input_dataset\', \'optimizations\', \'output_types\', \'output_shapes\', \'optimization_configs\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'None\'], "
}
member_method {
name: "OptionalFromValue"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册