/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_proto_maker.h"
#include <string>
#include <unordered_set>
#include <vector>

namespace paddle {
namespace framework {

void OpProtoAndCheckerMaker::Validate() {
  validated_ = true;
  CheckNoDuplicatedInOutAttrs();
24
  CheckReuseVars();
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
}

// Appends a new input entry to the op proto with the given name and comment,
// and returns a VariableBuilder wrapping it so callers can chain further
// settings on that input.
OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput(
    const std::string& name, const std::string& comment) {
  proto::OpProto::Var* new_input = proto_->add_inputs();
  new_input->set_comment(comment);
  new_input->set_name(name);
  return VariableBuilder{new_input};
}

// Appends a new output entry to the op proto with the given name and comment,
// and returns a VariableBuilder wrapping it so callers can chain further
// settings on that output.
OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput(
    const std::string& name, const std::string& comment) {
  proto::OpProto::Var* new_output = proto_->add_outputs();
  new_output->set_comment(comment);
  new_output->set_name(name);
  return VariableBuilder{new_output};
}

// Records `reused_name` as the reuse target of the output called `name`.
//
// Preconditions (enforced, in this order):
//   1. `reused_name` is the name of an already-added input.
//   2. `name` is the name of an already-added output.
//   3. That output does not yet have a reuse target (each output can be
//      marked at most once).
void OpProtoAndCheckerMaker::Reuse(const std::string& name,
                                   const std::string& reused_name) {
  bool found = false;
  for (const auto& input : proto_->inputs()) {
    if (input.name() == reused_name) {
      found = true;
      break;
    }
  }
  PADDLE_ENFORCE(found,
                 "Input/Output name: %s reused_name: %s, one of them is not "
                 "exists or not matched.",
                 name, reused_name);

  found = false;
  // Iterate the mutable repeated field directly so the matching output can be
  // updated in place.
  for (auto& output : *proto_->mutable_outputs()) {
    if (output.name() == name) {
      PADDLE_ENFORCE(!output.has_reuse(),
                     "Output(%s) has been set reused var of %s", name,
                     output.reuse());
      output.set_reuse(reused_name);
      found = true;
      break;
    }
  }
  PADDLE_ENFORCE(found,
                 "Input/Output name: %s reused_name: %s, one of them is not "
                 "exists or not matched.",
                 name, reused_name);
}

// Enforces that attribute, input and output names are pairwise distinct.
// All three lists share one namespace, so an input may not have the same name
// as an attribute or an output, and vice versa.
void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
  std::unordered_set<std::string> names;
  auto checker = [&](const std::string& name) {
    // insert() reports whether the key was new, so a single hash lookup both
    // detects a duplicate and records the name. The side effect is kept out
    // of the enforce macro's condition on purpose.
    const bool inserted = names.insert(name).second;
    PADDLE_ENFORCE(inserted, "[%s] is duplicated", name);
  };
  for (const auto& attr : proto_->attrs()) {
    checker(attr.name());
  }
  for (const auto& input : proto_->inputs()) {
    checker(input.name());
  }
  for (const auto& output : proto_->outputs()) {
    checker(output.name());
  }
}

// Enforces that every output carrying a `reuse` target names an input that is
// registered on this op. Outputs without a reuse target are skipped.
void OpProtoAndCheckerMaker::CheckReuseVars() {
  std::unordered_set<std::string> input_names;
  for (const auto& in : proto_->inputs()) {
    input_names.insert(in.name());
  }

  for (const auto& out : proto_->outputs()) {
    if (!out.has_reuse()) {
      continue;
    }
    PADDLE_ENFORCE(
        input_names.count(out.reuse()),
        "Output [%s] reuse Input [%s], but the input is not registered.",
        out.name(), out.reuse());
  }
}

// Builds an op proto. The steps are:
//   1. Bind the target proto and attribute checker.
//   2. Let Make() populate inputs/outputs/attrs (presumably overridden per
//      operator).
//   3. Append the common op-role attributes.
//   4. Run Validate() on the result.
void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
                                        OpAttrChecker* attr_checker) {
  proto_ = proto;
  op_checker_ = attr_checker;
  Make();

  // The role attribute is declared as a plain int, and some accepted values
  // are bitwise combinations of roles, so every OpRole is converted to int.
  auto as_int = [](OpRole role) { return static_cast<int>(role); };
  const int loss_role = as_int(OpRole::kLoss);
  const int unspecified_role = as_int(OpRole::kNotSpecified);

  AddAttr<int>(OpRoleAttrName(), "The role of this operator")
      .InEnum({as_int(OpRole::kForward), as_int(OpRole::kBackward),
               as_int(OpRole::kOptimize), as_int(OpRole::kRPC),
               loss_role | as_int(OpRole::kForward),
               loss_role | as_int(OpRole::kBackward), unspecified_role})
      .SetDefault(unspecified_role);

  AddAttr<std::vector<std::string>>(OpRoleVarAttrName(),
                                    "Optimized for variable")
      .SetDefault({});

  Validate();
}

}  // namespace framework
}  // namespace paddle