Commit b7fede02 authored by Pete Warden, committed by TensorFlower Gardener

Switched existing backporting rules to use the graph transform framework

Change: 150214498
Parent d0acae66
@@ -59,6 +59,7 @@ cc_library(
name = "transforms_lib",
srcs = [
"add_default_attributes.cc",
"backports.cc",
"fold_batch_norms.cc",
"fold_constants_lib.cc",
"fold_old_batch_norms.cc",
@@ -106,6 +107,7 @@ tf_cc_test(
size = "small",
srcs = [
"add_default_attributes_test.cc",
"backports_test.cc",
"fold_batch_norms_test.cc",
"fold_constants_test.cc",
"fold_old_batch_norms_test.cc",
......
@@ -13,6 +13,7 @@
* [Eight-bit Calculations](#eight-bit-calculations)
* [Transform Reference](#transform-reference)
* [add_default_attributes](#add_default_attributes)
* [backport_concatv2](#backport_concatv2)
* [fold_batch_norms](#fold_batch_norms)
* [fold_constants](#fold_constants)
* [fold_old_batch_norms](#fold_old_batch_norms)
@@ -336,6 +337,15 @@ can be useful to run this update process as a transform. This process finds any
op attributes that are defined in the current TensorFlow list of ops but not
within the saved model, and sets them to the defined default for that attribute.
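As a rough illustration of what this update does, the sketch below fills in missing attribute defaults directly through `AddDefaultAttrsToGraphDef` from `tensorflow/core/framework/graph_def_util.h`; the wrapper name `FillMissingAttrDefaults` is hypothetical, and the real `add_default_attributes` transform may differ in detail.

```cpp
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/op.h"

// Hypothetical helper: adds any attributes that the current op registry
// defines but the loaded GraphDef omits, using the registered defaults.
tensorflow::Status FillMissingAttrDefaults(tensorflow::GraphDef* graph_def) {
  return tensorflow::AddDefaultAttrsToGraphDef(
      graph_def, *tensorflow::OpRegistry::Global(), /*node_offset=*/0);
}
```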
### backport_concatv2
Args: None
If you have a GraphDef file that has been produced by a newer version of the
TensorFlow framework and includes ConcatV2, and you want to run it on an older
version that only supports Concat, this transform will take care of converting
those newer ops to the equivalent older form.
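As a programmatic sketch (mirroring the unit test added in this change), the transform function can also be called directly; the wrapper `BackportConcatV2Example` and the output name `"concat_node"` below are placeholders:

```cpp
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {

// Implemented in backports.cc and registered as "backport_concatv2".
Status BackportConcatV2Transform(const GraphDef& input_graph_def,
                                 const TransformFuncContext& context,
                                 GraphDef* output_graph_def);

// Rewrites every ConcatV2 node in the graph into a v1 Concat node.
Status BackportConcatV2Example(const GraphDef& input_graph_def,
                               GraphDef* output_graph_def) {
  TransformFuncContext context;
  context.input_names = {};
  context.output_names = {"concat_node"};  // Placeholder output name.
  return BackportConcatV2Transform(input_graph_def, context, output_graph_def);
}

}  // namespace graph_transforms
}  // namespace tensorflow
```

After the rewrite, each converted node keeps its data inputs in order but has the dimension input moved to position 0 and its `Tidx` attribute removed, as checked by the test added in this change.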
### fold_batch_norms
Args: None \
......
tensorflow/tools/graph_transforms/backports.cc
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Switch any ConcatV2 nodes to the v1 version, swapping the input order.
Status BackportConcatV2Transform(const GraphDef& input_graph_def,
                                 const TransformFuncContext& context,
                                 GraphDef* output_graph_def) {
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def, {"ConcatV2"},
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        const NodeDef& concat_v2_node = match.node;
        NodeDef concat_node = concat_v2_node;
        concat_node.set_op("Concat");
        // The last input is inserted at the head of the inputs, because Concat
        // expects the dimension as the first input (not the last as in
        // ConcatV2).
        concat_node.mutable_input()->Clear();
        const string& dim_input =
            concat_v2_node.input(concat_v2_node.input_size() - 1);
        concat_node.add_input(dim_input);
        for (int i = 0; i < (concat_v2_node.input_size() - 1); ++i) {
          concat_node.add_input(concat_v2_node.input(i));
        }
        // Tidx attribute must be deleted because it's not used in Concat.
        concat_node.mutable_attr()->erase("Tidx");
        new_nodes->push_back(concat_node);
        return Status::OK();
      },
      {true}, output_graph_def));
  return Status::OK();
}
REGISTER_GRAPH_TRANSFORM("backport_concatv2", BackportConcatV2Transform);
} // namespace graph_transforms
} // namespace tensorflow
tensorflow/tools/graph_transforms/backports_test.cc
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
Status BackportConcatV2Transform(const GraphDef& input_graph_def,
                                 const TransformFuncContext& context,
                                 GraphDef* output_graph_def);
class BackportConcatV2Test : public ::testing::Test {
 protected:
  void TestBackportConcatV2() {
    // Build a small graph that includes a ConcatV2 node; ConcatV2 takes the
    // concatenation dimension as its last input.
    GraphDef graph_def;

    NodeDef* mul_node1 = graph_def.add_node();
    mul_node1->set_name("mul_node1");
    mul_node1->set_op("Mul");
    mul_node1->add_input("add_node2");
    mul_node1->add_input("add_node3");

    NodeDef* add_node2 = graph_def.add_node();
    add_node2->set_name("add_node2");
    add_node2->set_op("Add");
    add_node2->add_input("const_node1");
    add_node2->add_input("const_node2");

    NodeDef* add_node3 = graph_def.add_node();
    add_node3->set_name("add_node3");
    add_node3->set_op("Add");
    add_node3->add_input("const_node1");
    add_node3->add_input("const_node3");

    NodeDef* const_node1 = graph_def.add_node();
    const_node1->set_name("const_node1");
    const_node1->set_op("Const");

    NodeDef* const_node2 = graph_def.add_node();
    const_node2->set_name("const_node2");
    const_node2->set_op("Const");

    NodeDef* const_node3 = graph_def.add_node();
    const_node3->set_name("const_node3");
    const_node3->set_op("Const");

    NodeDef* concat_node = graph_def.add_node();
    concat_node->set_name("concat_node");
    concat_node->set_op("ConcatV2");
    concat_node->add_input("const_node1");
    concat_node->add_input("const_node2");
    concat_node->add_input("const_node3");
    SetNodeAttr("Tidx", DT_INT32, concat_node);

    GraphDef result;
    TransformFuncContext context;
    context.input_names = {};
    context.output_names = {"concat_node"};
    TF_ASSERT_OK(BackportConcatV2Transform(graph_def, context, &result));

    std::map<string, const NodeDef*> node_lookup;
    MapNamesToNodes(result, &node_lookup);
    // The node should now be a v1 Concat, with the former last input (the
    // dimension) moved to position 0 and the Tidx attribute removed.
    EXPECT_EQ(1, node_lookup.count("concat_node"));
    EXPECT_EQ("Concat", node_lookup.at("concat_node")->op());
    EXPECT_EQ(0, node_lookup.at("concat_node")->attr().count("Tidx"));
    EXPECT_EQ("const_node3", node_lookup.at("concat_node")->input(0));
    EXPECT_EQ("const_node1", node_lookup.at("concat_node")->input(1));
    EXPECT_EQ("const_node2", node_lookup.at("concat_node")->input(2));
    EXPECT_EQ(1, node_lookup.count("const_node1"));
    EXPECT_EQ("Const", node_lookup.at("const_node1")->op());
    EXPECT_EQ(1, node_lookup.count("const_node2"));
    EXPECT_EQ("Const", node_lookup.at("const_node2")->op());
    EXPECT_EQ(1, node_lookup.count("const_node3"));
    EXPECT_EQ("Const", node_lookup.at("const_node3")->op());
  }
};

TEST_F(BackportConcatV2Test, TestBackportConcatV2) { TestBackportConcatV2(); }
} // namespace graph_transforms
} // namespace tensorflow
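For completeness, once `REGISTER_GRAPH_TRANSFORM("backport_concatv2", ...)` has run, the transform is also reachable by name through the framework's generic driver. The sketch below assumes the `ParseTransformParameters` and `TransformGraph` helpers declared in `transform_graph.h` (check that header for the exact signatures); the node names `"input"` and `"output"` are placeholders.

```cpp
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/tools/graph_transforms/transform_graph.h"

namespace tensorflow {
namespace graph_transforms {

// Applies the registered "backport_concatv2" transform by name, roughly the
// way the transform_graph command-line tool would.
Status BackportByName(GraphDef* graph_def) {
  TransformParameters params;
  Status parse_status =
      ParseTransformParameters("backport_concatv2", &params);
  if (!parse_status.ok()) return parse_status;
  return TransformGraph({"input"}, {"output"}, params, graph_def);
}

}  // namespace graph_transforms
}  // namespace tensorflow
```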