提交 76e5985f 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

[Grappler] Improve removal of redundant ops in arithmetic optimizer.

1. Remove identity Reshape ops with control inputs.
2. Remove identity BroadcastTo ops. This op is common, as it appears in the gradient of Sum.

PiperOrigin-RevId: 318163412
Change-Id: I63954d71c4ed4ff8636bdad651ac7f18d0518fc0
上级 7f236df1
......@@ -117,6 +117,8 @@ bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }
bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }
bool IsBroadcastTo(const NodeDef& node) { return node.op() == "BroadcastTo"; }
bool IsCast(const NodeDef& node) { return node.op() == "Cast"; }
bool IsCastLike(const NodeDef& node) {
......
......@@ -47,6 +47,7 @@ bool IsBetainc(const NodeDef& node);
bool IsBiasAdd(const NodeDef& node);
bool IsBiasAddGrad(const NodeDef& node);
bool IsBitcast(const NodeDef& node);
bool IsBroadcastTo(const NodeDef& node);
bool IsCast(const NodeDef& node);
bool IsCheckNumerics(const NodeDef& node);
bool IsCollective(const NodeDef& node);
......
......@@ -1917,15 +1917,22 @@ class LogSoftmaxStage : public ArithmeticOptimizerStage {
// ^ |
// | |
// input input ---+
class RemoveRedundantReshape : public ArithmeticOptimizerStage {
//
// Additionally, Reshape and BroadcastTo nodes where the
// input and target shapes are equal are bypassed.
//
class RemoveRedundantReshapeOrBroadcastTo : public ArithmeticOptimizerStage {
public:
explicit RemoveRedundantReshape(const GraphOptimizerContext& ctx,
const ArithmeticOptimizerContext& ctx_ext)
: ArithmeticOptimizerStage("RemoveRedundantReshape", ctx, ctx_ext) {}
~RemoveRedundantReshape() override = default;
explicit RemoveRedundantReshapeOrBroadcastTo(
const GraphOptimizerContext& ctx,
const ArithmeticOptimizerContext& ctx_ext)
: ArithmeticOptimizerStage("RemoveRedundantReshapeOrBroadcastTo", ctx,
ctx_ext) {}
~RemoveRedundantReshapeOrBroadcastTo() override = default;
bool IsSupported(const NodeDef* node) const override {
return IsReshape(*node) && !IsInPreserveSet(*node);
return (IsReshape(*node) || IsBroadcastTo(*node)) &&
!IsInPreserveSet(*node);
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
......@@ -1933,7 +1940,8 @@ class RemoveRedundantReshape : public ArithmeticOptimizerStage {
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
// 1. Bypass reshape followed by reshape.
if (IsReshape(*input) && !HasControlInputs(*input)) {
if (IsReshape(*node) && IsReshape(*input)) {
ForwardControlDependencies(node, {input});
node->set_input(0, input->input(0));
ctx().node_map->UpdateInput(node->name(), input->name(), input->input(0));
*simplified_node_name = node->name();
......@@ -1944,7 +1952,7 @@ class RemoveRedundantReshape : public ArithmeticOptimizerStage {
// 2. If the reshape is a no-op, forward its input to its consumers, unless
// it anchors a control dependency since we want to make sure that control
// dependency is triggered.
if (ReshapeIsIdentity(*node) && !HasControlInputs(*node)) {
if (InputMatchesTargetShape(*node) && !HasControlInputs(*node)) {
*simplified_node_name = node->input(0);
return Status::OK();
}
......@@ -1954,7 +1962,7 @@ class RemoveRedundantReshape : public ArithmeticOptimizerStage {
private:
// Returns whether `reshape` is an identity op.
bool ReshapeIsIdentity(const NodeDef& reshape) {
bool InputMatchesTargetShape(const NodeDef& reshape) {
const OpInfo::TensorProperties* reshape_props;
const OpInfo::TensorProperties* input_props;
......@@ -3673,7 +3681,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
if (options_.remove_redundant_cast)
pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
if (options_.remove_redundant_reshape)
pipeline.AddStage<RemoveRedundantReshape>(ctx, ctx_ext);
pipeline.AddStage<RemoveRedundantReshapeOrBroadcastTo>(ctx, ctx_ext);
if (options_.remove_negation)
pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
if (options_.replace_mul_with_square)
......
......@@ -828,37 +828,45 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
}
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeIdentityReshape) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =
ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28}));
Output inputs_shape = ops::Shape(s, inputs);
// The target shape of the reshape is the concatenation of `batch_size` and
// [3,28,28].
Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}),
ops::Const(s, {1}, {1}));
Output target_shape = ops::Concat(
s.WithOpName("target_shape"),
{batch_size, ops::Const(s, {3, 28, 28}, {3})}, ops::Const(s, {0}, {}));
Output reshape = ops::Reshape(s, inputs, target_shape);
Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);
for (bool is_broadcastto : {false, true}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =
ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28}));
Output inputs_shape = ops::Shape(s, inputs);
// The target shape of the reshape is the concatenation of `batch_size` and
// [3,28,28].
Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}),
ops::Const(s, {1}, {1}));
Output target_shape = ops::Concat(
s.WithOpName("target_shape"),
{batch_size, ops::Const(s, {3, 28, 28}, {3})}, ops::Const(s, {0}, {}));
if (is_broadcastto) {
Output outputs = ops::Identity(s.WithOpName("outputs"),
ops::BroadcastTo(s, inputs, target_shape));
} else {
Output outputs = ops::Identity(s.WithOpName("outputs"),
ops::Reshape(s, inputs, target_shape));
}
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
auto tensors_expected =
EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}});
ASSERT_EQ(tensors_expected.size(), 1);
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
auto tensors_expected =
EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}});
ASSERT_EQ(tensors_expected.size(), 1);
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveRedundantReshape(&optimizer);
OptimizeTwiceAndPrune(&optimizer, &item, &output);
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveRedundantReshape(&optimizer);
OptimizeTwiceAndPrune(&optimizer, &item, &output);
EXPECT_EQ(CountOpNodes(output, "Reshape"), 0);
auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}});
ASSERT_EQ(tensors.size(), 1);
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
EXPECT_EQ(CountOpNodes(output, "Reshape"), 0);
EXPECT_EQ(CountOpNodes(output, "BroadcastTo"), 0);
auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}});
ASSERT_EQ(tensors.size(), 1);
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
}
TEST_F(ArithmeticOptimizerTest,
......@@ -1023,7 +1031,9 @@ TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeCombineReshapes) {
ops::Const(s.WithOpName("perm"), {0, 2, 3, 1, 4}, {5}));
Output nhwc = ops::Reshape(
s.WithOpName("nhwc"), transpose,
ops::Const(s.WithOpName("nhwc_shape"), {8, 28, 28, 12}, {4}));
ops::Const(
s.WithControlDependencies(nchw_vect_c).WithOpName("nhwc_shape"),
{8, 28, 28, 12}, {4}));
Output flatten = ops::Reshape(
s.WithOpName("flatten"), nhwc,
ops::Const(s.WithOpName("flatten_shape"), {8, 28 * 28 * 12}, {2}));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册