op_desc.cpp 1.7 KB
Newer Older
朔-望's avatar
朔-望 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
//
// Created by liuRuiLong on 2018/5/4.
//

#include "op_desc.h"

namespace paddle_mobile {
namespace framework {

/// Builds the in-memory view of a serialized operator description.
///
/// `desc` is stored in `desc_` through the member initializer, and the stored
/// message is then indexed:
///  - `inputs_` / `outputs_`: each Var entry's parameter name is mapped to its
///    argument (variable) names, in the order they appear. The entry is created
///    even when the Var has no arguments. If a parameter name appears in
///    several Var entries, their arguments are appended to the same list.
///  - `attrs_`: every attribute whose type is not BLOCK is converted with
///    `Attribute::GetAttrValue` and stored under its name. BLOCK attributes are
///    skipped here.
///
/// @param desc Protobuf operator description to index.
OpDesc::OpDesc(const proto::OpDesc &desc) : desc_(desc) {
  for (const proto::OpDesc::Var &input_var : desc_.inputs()) {
    std::vector<std::string> &bound_names = inputs_[input_var.parameter()];
    for (const std::string &argument : input_var.arguments()) {
      bound_names.push_back(argument);
    }
  }

  for (const proto::OpDesc::Var &output_var : desc_.outputs()) {
    std::vector<std::string> &bound_names = outputs_[output_var.parameter()];
    for (const std::string &argument : output_var.arguments()) {
      bound_names.push_back(argument);
    }
  }

  for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
    // BLOCK attributes are not converted into `attrs_`; where (or whether)
    // they are consumed is not visible in this file.
    if (attr.type() == proto::AttrType::BLOCK) {
      continue;
    }
    attrs_[attr.name()] = Attribute::GetAttrValue(attr);
  }
}

/// Returns the argument (variable) names bound to the input parameter `name`.
///
/// @param name Input parameter name as recorded in the operator description.
/// @return Reference to the name list stored in `inputs_`.
/// @throws std::out_of_range if this op has no input parameter called `name`.
const std::vector<std::string> &OpDesc::Input(const std::string &name) const {
  // at() instead of find()->second: a missing key must raise an error rather
  // than dereference end(), which is undefined behavior.
  return inputs_.at(name);
}

/// Returns the argument (variable) names bound to the output parameter `name`.
///
/// @param name Output parameter name as recorded in the operator description.
/// @return Reference to the name list stored in `outputs_`.
/// @throws std::out_of_range if this op has no output parameter called `name`.
const std::vector<std::string> &OpDesc::Output(const std::string &name) const {
  // at() instead of find()->second: a missing key must raise an error rather
  // than dereference end(), which is undefined behavior.
  return outputs_.at(name);
}

/// Returns a copy of the attribute stored under `name`.
///
/// Only attributes that the constructor placed in `attrs_` can be found. The
/// constructor skips BLOCK-typed attributes.
///
/// @param name Attribute name.
/// @return The attribute value, returned by value.
/// @throws std::out_of_range if no attribute called `name` is present.
Attribute OpDesc::GetAttr(const std::string &name) const {
  // The previous find() result was dereferenced without comparing it to end(),
  // which is undefined behavior for unknown names. at() checks the key.
  return attrs_.at(name);
}

const std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() const {
  return attrs_;
}

}  // namespace framework
}  // namespace paddle_mobile