/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/prune.h"

#include <gtest/gtest.h>
#include <set>
#include <string>

#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/operator.h"

#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"

namespace f = paddle::framework;

// Test helper: appends an operator of kind `type` to `block`.
//
// For every variable name listed in `outputs`, the VarDesc is obtained via
// block->Var() and its data type is set to FP32. Then a new OpDesc is appended
// with the given type, input slots, output slots and attributes. Variables
// named only in `inputs` are not touched here.
//
// @param type     operator type string stored on the new OpDesc.
// @param inputs   slot name -> list of input variable names.
// @param outputs  slot name -> list of output variable names.
// @param attrs    attribute map copied onto the new OpDesc.
// @param block    block that receives the output variables and the new op.
void AddOp(const std::string &type, const f::VariableNameMap &inputs,
           const f::VariableNameMap &outputs, f::AttributeMap attrs,
           paddle::framework::BlockDesc *block) {
  // insert output vars. Iterate by const reference so neither the
  // (slot, names) pairs nor the individual names are copied.
  for (const auto &kv : outputs) {
    for (const auto &name : kv.second) {
      auto var = block->Var(name);
      var->SetDataType(paddle::framework::proto::VarType::FP32);
    }
  }

  // insert op
  auto op = block->AppendOp();
  op->SetType(type);
  for (const auto &kv : inputs) {
    op->SetInput(kv.first, kv.second);
  }
  for (const auto &kv : outputs) {
    op->SetOutput(kv.first, kv.second);
  }
  op->SetAttrMap(attrs);
}

// A program with a single one_one op (a -> b).
// Before any op is flagged as a target, pruning is expected to keep no ops.
// After op 0 is flagged as a target and "a" is added to the feed set, pruning
// is expected to keep exactly that one op.
TEST(Prune, one_operator) {
  f::ProgramDesc program;
  f::BlockDesc *block = program.MutableBlock(0);

  AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, f::AttributeMap{},
        block);

  f::proto::ProgramDesc *pdesc = program.Proto();
  f::proto::ProgramDesc pruned;

  // No target op yet: nothing should survive pruning.
  std::set<std::string> feed_var_names = {};
  f::Prune(*pdesc, feed_var_names, &pruned);
  EXPECT_EQ(pruned.blocks(0).ops_size(), 0);

  feed_var_names.insert("a");
  pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true);
  // NOTE(review): `pruned` is reused here. The expectation assumes Prune
  // replaces (rather than appends to) its output -- confirm in prune.cc.
  f::Prune(*pdesc, feed_var_names, &pruned);
  EXPECT_EQ(pruned.blocks(0).ops_size(), 1);
}

// A linear chain of four one_one ops: a -> b -> c -> d -> e, with "a" fed.
// On iteration i, op i is additionally flagged as a target. Flags set in
// earlier iterations persist in `pdesc`, so targets accumulate. Pruning is
// expected to keep i + 1 ops each time.
TEST(Prune, forward) {
  f::ProgramDesc program;
  f::BlockDesc *block = program.MutableBlock(0);

  AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, f::AttributeMap{},
        block);
  AddOp("one_one", {{"input", {"b"}}}, {{"output", {"c"}}}, f::AttributeMap{},
        block);
  AddOp("one_one", {{"input", {"c"}}}, {{"output", {"d"}}}, f::AttributeMap{},
        block);
  AddOp("one_one", {{"input", {"d"}}}, {{"output", {"e"}}}, f::AttributeMap{},
        block);

  f::proto::ProgramDesc *pdesc = program.Proto();
  std::set<std::string> feed_var_names = {"a"};
  for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) {
    // Fresh output per iteration so each check is independent.
    f::proto::ProgramDesc pruned;
    pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true);
    f::Prune(*pdesc, feed_var_names, &pruned);
    EXPECT_EQ(pruned.blocks(0).ops_size(), i + 1);
  }
}

// Three independent one_one ops (a0 -> b0, a1 -> b1, a2 -> b2) feed a single
// three_one op (b0, b1, b2 -> c). Only the three_one op (index 3) is flagged
// as a target, and a0..a2 are fed. All four ops are expected to be kept,
// because the target consumes every producer's output.
TEST(Prune, multi_input_op) {
  f::ProgramDesc program;
  f::BlockDesc *block = program.MutableBlock(0);

  AddOp("one_one", {{"input", {"a0"}}}, {{"output", {"b0"}}}, f::AttributeMap{},
        block);
  AddOp("one_one", {{"input", {"a1"}}}, {{"output", {"b1"}}}, f::AttributeMap{},
        block);
  AddOp("one_one", {{"input", {"a2"}}}, {{"output", {"b2"}}}, f::AttributeMap{},
        block);
  AddOp("three_one", {{"input", {"b0", "b1", "b2"}}}, {{"output", {"c"}}},
        f::AttributeMap{}, block);

  f::proto::ProgramDesc *pdesc = program.Proto();
  pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true);

  f::proto::ProgramDesc pruned;
  std::set<std::string> feed_var_names = {"a0", "a1", "a2"};
  f::Prune(*pdesc, feed_var_names, &pruned);
  EXPECT_EQ(pruned.blocks(0).ops_size(), 4);
}

// A one_two op (a -> b, c) feeds two branches: b -> b1 (op 1) and c -> c1
// (op 2). Only op 2 is flagged as a target. Pruning is expected to keep two
// ops: the target and the one_two op producing its input "c".
TEST(Prune, multi_output_op) {
  f::ProgramDesc program;
  f::BlockDesc *block = program.MutableBlock(0);

  AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}},
        f::AttributeMap{}, block);
  AddOp("one_one", {{"input", {"b"}}}, {{"output", {"b1"}}}, f::AttributeMap{},
        block);
  AddOp("one_one", {{"input", {"c"}}}, {{"output", {"c1"}}}, f::AttributeMap{},
        block);

  f::proto::ProgramDesc *pdesc = program.Proto();
  pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);

  f::proto::ProgramDesc pruned;
  std::set<std::string> feed_var_names = {"a"};
  f::Prune(*pdesc, feed_var_names, &pruned);
  EXPECT_EQ(pruned.blocks(0).ops_size(), 2);
}

// Same graph as multi_output_op: a -> (b, c), b -> b1, c -> c1. Here both
// branch ops (1 and 2) are flagged as targets, so all three ops are expected
// to be kept.
TEST(Prune, multi_target) {
  f::ProgramDesc program;
  f::BlockDesc *block = program.MutableBlock(0);

  AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}},
        f::AttributeMap{}, block);
  AddOp("one_one", {{"input", {"b"}}}, {{"output", {"b1"}}}, f::AttributeMap{},
        block);
  AddOp("one_one", {{"input", {"c"}}}, {{"output", {"c1"}}}, f::AttributeMap{},
        block);

  f::proto::ProgramDesc *pdesc = program.Proto();
  pdesc->mutable_blocks(0)->mutable_ops(1)->set_is_target(true);
  pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);

  f::proto::ProgramDesc pruned;
  std::set<std::string> feed_var_names = {"a"};
  f::Prune(*pdesc, feed_var_names, &pruned);
  EXPECT_EQ(pruned.blocks(0).ops_size(), 3);
}