/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/ops_extra_info.h"

#include "glog/logging.h"

namespace paddle {
namespace framework {

std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
24 25 26 27
    const std::string& type,
    const VariableNameMap& inputs,
    const VariableNameMap& outputs,
    const AttributeMap& attrs,
28
    bool attr_check) {
29 30 31 32 33 34 35 36 37 38 39 40
  AttributeMap standard_attrs;
  AttributeMap runtime_attrs =
      paddle::operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(type);
  for (auto& attr : attrs) {
    auto it = runtime_attrs.find(attr.first);
    if (it != runtime_attrs.end()) {
      it->second = attr.second;
    } else {
      standard_attrs[attr.first] = attr.second;
    }
  }
  auto& info = OpInfoMap::Instance().Get(type);
41 42 43 44 45 46 47 48 49 50 51
  if (attr_check) {
    if (info.Checker() != nullptr) {
      info.Checker()->Check(&standard_attrs);
    }
    const auto& extra_attr_checkers =
        operators::ExtraInfoUtils::Instance().GetExtraAttrsChecker(type);
    if (!extra_attr_checkers.empty()) {
      for (const auto& checker : extra_attr_checkers) {
        checker(&runtime_attrs, false);
      }
    }
52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
  }
  auto op_base = std::unique_ptr<OperatorBase>(
      info.Creator()(type, inputs, outputs, standard_attrs));
  op_base->SetRuntimeAttributeMap(runtime_attrs);
  return op_base;
}

std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
    const std::string& type,
    const VariableNameMap& inputs,
    const VariableNameMap& outputs,
    const AttributeMap& attrs,
    const AttributeMap& runtime_attrs,
    bool attr_check) {
  std::unique_ptr<OperatorBase> op_base;
Y
Yu Yang 已提交
67
  auto& info = OpInfoMap::Instance().Get(type);
H
hong 已提交
68
  if (attr_check && info.Checker() != nullptr) {
69 70
    auto tmp_attrs = attrs;
    info.Checker()->Check(&tmp_attrs);
71
    op_base = std::unique_ptr<OperatorBase>(
72
        info.Creator()(type, inputs, outputs, tmp_attrs));
73 74 75
  } else {
    op_base = std::unique_ptr<OperatorBase>(
        info.Creator()(type, inputs, outputs, attrs));
Y
Stash  
Yu Yang 已提交
76
  }
77 78 79 80 81 82 83 84 85
  const auto& extra_attr_checkers =
      operators::ExtraInfoUtils::Instance().GetExtraAttrsChecker(type);
  if (!extra_attr_checkers.empty()) {
    auto op_runtime_attr_map = runtime_attrs;
    for (const auto& checker : extra_attr_checkers) {
      checker(&op_runtime_attr_map, false);
    }
    op_base->SetRuntimeAttributeMap(op_runtime_attr_map);
  }
86
  return op_base;
87 88
}

// Converts the repeated `Var` entries of a proto OpDesc into a VariableNameMap
// keyed by each entry's parameter() name. The entry's arguments() are appended,
// in order, to that key's name list, so entries sharing a parameter name are
// concatenated.
static VariableNameMap ConvertOpDescVarsToVarNameMap(
    const google::protobuf::RepeatedPtrField<proto::OpDesc::Var>&
        op_desc_vars) {
  VariableNameMap name_map;
  for (const auto& proto_var : op_desc_vars) {
    const auto& arguments = proto_var.arguments();
    auto& names = name_map[proto_var.parameter()];
    names.reserve(static_cast<size_t>(arguments.size()));
    names.insert(names.end(), arguments.begin(), arguments.end());
  }
  return name_map;
}

// Deprecated: builds an operator directly from a protobuf OpDesc; prefer
// CreateOp(const OpDesc&).
//
// Inputs and outputs are converted with ConvertOpDescVarsToVarNameMap. Each
// proto attribute whose name appears in the op type's registered
// extra-attribute map overrides that runtime entry. All other attributes become
// standard attributes. Both maps are then forwarded to the
// (type, inputs, outputs, attrs, runtime_attrs) overload.
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
    const proto::OpDesc& op_desc) {
  // The adjacent literals are concatenated, so the trailing space in "be " is
  // required to avoid "beused" in the log line.
  VLOG(1) << "CreateOp directly from OpDesc is deprecated. It should only be "
             "used in unit tests. Use CreateOp(const OpDesc& op_desc) "
             "instead.";
  VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
  VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
  AttributeMap attrs;
  AttributeMap extra_attrs =
      paddle::operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(
          op_desc.type());
  for (auto& attr : op_desc.attrs()) {
    auto it = extra_attrs.find(attr.name());
    if (it != extra_attrs.end()) {
      it->second = GetAttrValue(attr);
    } else {
      attrs[attr.name()] = GetAttrValue(attr);
    }
  }

  return CreateOp(op_desc.type(), inputs, outputs, attrs, extra_attrs);
}

std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
128 129 130
  return CreateOp(op_desc.Type(),
                  op_desc.Inputs(),
                  op_desc.Outputs(),
131 132
                  op_desc.GetAttrMap(),
                  op_desc.GetRuntimeAttrMap());
133 134 135
}

}  // namespace framework
}  // namespace paddle