From b7fede02c1f8232ebc4ff035cfcf63772c64c741 Mon Sep 17 00:00:00 2001 From: Pete Warden Date: Wed, 15 Mar 2017 10:06:30 -0800 Subject: [PATCH] Switched existing backporting rules to use the graph transform framework Change: 150214498 --- tensorflow/tools/graph_transforms/BUILD | 2 + tensorflow/tools/graph_transforms/README.md | 10 ++ .../tools/graph_transforms/backports.cc | 65 +++++++++++ .../tools/graph_transforms/backports_test.cc | 105 ++++++++++++++++++ 4 files changed, 182 insertions(+) create mode 100644 tensorflow/tools/graph_transforms/backports.cc create mode 100644 tensorflow/tools/graph_transforms/backports_test.cc diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index fca411337a9..ee6d013cc0c 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -59,6 +59,7 @@ cc_library( name = "transforms_lib", srcs = [ "add_default_attributes.cc", + "backports.cc", "fold_batch_norms.cc", "fold_constants_lib.cc", "fold_old_batch_norms.cc", @@ -106,6 +107,7 @@ tf_cc_test( size = "small", srcs = [ "add_default_attributes_test.cc", + "backports_test.cc", "fold_batch_norms_test.cc", "fold_constants_test.cc", "fold_old_batch_norms_test.cc", diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md index b036084207f..6597adb68a0 100644 --- a/tensorflow/tools/graph_transforms/README.md +++ b/tensorflow/tools/graph_transforms/README.md @@ -13,6 +13,7 @@ * [Eight-bit Calculations](#eight-bit-calculations) * [Transform Reference](#transform-reference) * [add_default_attributes](#add_default_attributes) + * [backport_concatv2](#backport_concatv2) * [fold_batch_norms](#fold_batch_norms) * [fold_constants](#fold_constants) * [fold_old_batch_norms](#fold_old_batch_norms) @@ -336,6 +337,15 @@ can be useful to run this update process as a transform. This process finds any op attributes that are defined in the current TensorFlow list of ops but not within the saved model, and sets them to the defined default for that attribute. +### backport_concatv2 + +Args: None + +If you have a GraphDef file that has been produced by a newer version of the +TensorFlow framework and includes ConcatV2, and you want to run it on an older +version that only supports Concat, this transform will take care of converting +those newer ops to the equivalent older form. + ### fold_batch_norms Args: None \ diff --git a/tensorflow/tools/graph_transforms/backports.cc b/tensorflow/tools/graph_transforms/backports.cc new file mode 100644 index 00000000000..3b1d57146bb --- /dev/null +++ b/tensorflow/tools/graph_transforms/backports.cc @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/fold_constants_lib.h" + +#include "tensorflow/core/common_runtime/constant_folding.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Switch any ConcatV2 nodes to the v1 version, swapping the input order. +Status BackportConcatV2Transform(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + input_graph_def, {"ConcatV2"}, + [](const NodeMatch& match, const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + const NodeDef& concat_v2_node = match.node; + NodeDef concat_node = concat_v2_node; + concat_node.set_op("Concat"); + // The last input is inserted at the head of the inputs, because Concat + // expects the dimension as the first input (not the last as in + // ConcatV2). + concat_node.mutable_input()->Clear(); + const string& dim_input = + concat_v2_node.input(concat_v2_node.input_size() - 1); + concat_node.add_input(dim_input); + for (int i = 0; i < (concat_v2_node.input_size() - 1); ++i) { + concat_node.add_input(concat_v2_node.input(i)); + } + // Tidx attribute must be deleted because it's not used in Concat. + concat_node.mutable_attr()->erase("Tidx"); + new_nodes->push_back(concat_node); + return Status::OK(); + }, + {true}, output_graph_def)); + + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("backport_concatv2", BackportConcatV2Transform); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/backports_test.cc b/tensorflow/tools/graph_transforms/backports_test.cc new file mode 100644 index 00000000000..021cb14136a --- /dev/null +++ b/tensorflow/tools/graph_transforms/backports_test.cc @@ -0,0 +1,105 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status BackportConcatV2Transform(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class BackportConcatV2Test : public ::testing::Test { + protected: + void TestBackportConcatV2() { + GraphDef graph_def; + + NodeDef* mul_node1 = graph_def.add_node(); + mul_node1->set_name("mul_node1"); + mul_node1->set_op("Mul"); + mul_node1->add_input("add_node2"); + mul_node1->add_input("add_node3"); + + NodeDef* add_node2 = graph_def.add_node(); + add_node2->set_name("add_node2"); + add_node2->set_op("Add"); + add_node2->add_input("const_node1"); + add_node2->add_input("const_node2"); + + NodeDef* add_node3 = graph_def.add_node(); + add_node3->set_name("add_node3"); + add_node3->set_op("Add"); + add_node3->add_input("const_node1"); + add_node3->add_input("const_node3"); + + NodeDef* const_node1 = graph_def.add_node(); + const_node1->set_name("const_node1"); + const_node1->set_op("Const"); + + NodeDef* const_node2 = graph_def.add_node(); + const_node2->set_name("const_node2"); + const_node2->set_op("Const"); + + NodeDef* const_node3 = graph_def.add_node(); + const_node3->set_name("const_node3"); + const_node3->set_op("Const"); + + NodeDef* concat_node = graph_def.add_node(); + concat_node->set_name("concat_node"); + concat_node->set_op("ConcatV2"); + concat_node->add_input("const_node1"); + concat_node->add_input("const_node2"); + concat_node->add_input("const_node3"); + SetNodeAttr("Tidx", DT_INT32, concat_node); + + GraphDef result; + TransformFuncContext context; + context.input_names = {}; + context.output_names = {"concat_node"}; + TF_ASSERT_OK(BackportConcatV2Transform(graph_def, context, &result)); + + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + EXPECT_EQ(1, node_lookup.count("concat_node")); + EXPECT_EQ("Concat", node_lookup.at("concat_node")->op()); + EXPECT_EQ(0, node_lookup.at("concat_node")->attr().count("Tidx")); + EXPECT_EQ("const_node3", node_lookup.at("concat_node")->input(0)); + EXPECT_EQ("const_node1", node_lookup.at("concat_node")->input(1)); + EXPECT_EQ("const_node2", node_lookup.at("concat_node")->input(2)); + EXPECT_EQ(1, node_lookup.count("const_node1")); + EXPECT_EQ("Const", node_lookup.at("const_node1")->op()); + EXPECT_EQ(1, node_lookup.count("const_node2")); + EXPECT_EQ("Const", node_lookup.at("const_node2")->op()); + EXPECT_EQ(1, node_lookup.count("const_node3")); + EXPECT_EQ("Const", node_lookup.at("const_node3")->op()); + } +}; + +TEST_F(BackportConcatV2Test, TestBackportConcatV2) { TestBackportConcatV2(); } + +} // namespace graph_transforms +} // namespace tensorflow -- GitLab