/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/operator.h"
#include <algorithm>
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace framework {

Q
qijun 已提交
22
template <>
23
Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
Q
qijun 已提交
24
    platform::CPUPlace, Eigen::DefaultDevice>() const {
D
dongzhihong 已提交
25
  return *device_context_->get_eigen_device<Eigen::DefaultDevice>();
Q
qijun 已提交
26 27 28 29
}

#ifndef PADDLE_ONLY_CPU
template <>
30
Eigen::GpuDevice&
31
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
D
dongzhihong 已提交
32
  return *device_context_->get_eigen_device<Eigen::GpuDevice>();
Q
qijun 已提交
33 34 35
}
#endif

// Returns the process-wide map from operator type name to its OpProto.
//
// The map is created on first use and is intentionally never deleted
// (construct-on-first-use). Presumably this is so it stays valid for code that
// runs during static initialization/destruction in other translation units.
// TODO confirm against the registrars in op_registry.h.
// A function-local static is used for the pointer so that the one-time
// creation is thread-safe even if several threads make the first call
// concurrently. A plain `if (ptr == nullptr) ptr = new ...` check is a data
// race in that situation.
std::unordered_map<std::string, OpProto>& OpProtos() {
  static auto* op_protos = new std::unordered_map<std::string, OpProto>();
  return *op_protos;
}

Y
Yan Chunwei 已提交
44
const std::string& OperatorBase::Input(const std::string& name) const {
Y
Yu Yang 已提交
45 46
  auto& ins = Inputs(name);
  PADDLE_ENFORCE_EQ(ins.size(), 1UL,
Y
Yu Yang 已提交
47 48
                    "Op %s input %s should contain only one variable", type_,
                    name);
Y
Yu Yang 已提交
49
  return ins[0];
Y
Yan Chunwei 已提交
50 51
}

Y
Yu Yang 已提交
52 53
const std::vector<std::string>& OperatorBase::Inputs(
    const std::string& name) const {
Y
Yu Yang 已提交
54 55 56 57
  auto it = inputs_.find(name);
  PADDLE_ENFORCE(it != inputs_.end(), "Op %s do not have input %s", type_,
                 name);
  return it->second;
Y
Yan Chunwei 已提交
58 59 60
}

const std::string& OperatorBase::Output(const std::string& name) const {
Y
Yu Yang 已提交
61 62 63
  auto& outs = Outputs(name);
  PADDLE_ENFORCE_EQ(outs.size(), 1UL,
                    "Op %s output %s should contain only one variable", type_,
Y
Yu Yang 已提交
64
                    name);
Y
Yu Yang 已提交
65
  return outs[0];
Y
Yan Chunwei 已提交
66 67
}

Y
Yu Yang 已提交
68 69
const std::vector<std::string>& OperatorBase::Outputs(
    const std::string& name) const {
Y
Yu Yang 已提交
70 71 72 73
  auto it = outputs_.find(name);
  PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output %s", type_,
                 name);
  return it->second;
Y
Yan Chunwei 已提交
74 75
}

Q
Qiao Longfei 已提交
76 77
std::string OperatorBase::DebugString() const {
  std::stringstream ss;
Y
Yu Yang 已提交
78
  ss << "Op(" << type_ << "), inputs:{";
Y
Yu Yang 已提交
79 80
  for (auto it = inputs_.begin(); it != inputs_.end();) {
    auto& input = *it;
Y
Yu Yang 已提交
81 82 83 84 85 86
    ss << input.first << "[";
    for (size_t i = 0; i < input.second.size(); ++i) {
      ss << input.second[i];
      if (i != input.second.size() - 1) {
        ss << ", ";
      }
87
    }
Y
Yu Yang 已提交
88
    ss << "]";
Y
Yu Yang 已提交
89 90 91 92
    ++it;
    if (it != inputs_.end()) {
      ss << ", ";
    }
Q
Qiao Longfei 已提交
93
  }
Y
Yu Yang 已提交
94
  ss << "}, outputs:{";
Y
Yu Yang 已提交
95 96
  for (auto it = outputs_.begin(); it != outputs_.end();) {
    auto& output = *it;
Y
Yu Yang 已提交
97 98 99 100 101 102
    ss << output.first << "[";
    for (size_t i = 0; i < output.second.size(); ++i) {
      ss << output.second[i];
      if (i != output.second.size() - 1) {
        ss << ", ";
      }
103
    }
Y
Yu Yang 已提交
104
    ss << "]";
Y
Yu Yang 已提交
105 106 107 108
    ++it;
    if (it != outputs_.end()) {
      ss << ", ";
    }
Q
Qiao Longfei 已提交
109
  }
Y
Yu Yang 已提交
110
  ss << "}.";
Q
Qiao Longfei 已提交
111 112 113
  return ss.str();
}

D
dongzhihong 已提交
114 115
void OperatorBase::Rename(const std::string& old_name,
                          const std::string& new_name) {
Y
Yu Yang 已提交
116 117 118 119 120 121 122
  for (auto& input : inputs_) {
    std::replace(input.second.begin(), input.second.end(), old_name, new_name);
  }
  for (auto& output : outputs_) {
    std::replace(output.second.begin(), output.second.end(), old_name,
                 new_name);
  }
D
dongzhihong 已提交
123 124
}

Y
Yu Yang 已提交
125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
OperatorBase::OperatorBase(const std::string& type,
                           const OperatorBase::VarNameMap& inputs,
                           const OperatorBase::VarNameMap& outputs,
                           const AttributeMap& attrs)
    : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {
  static std::atomic<size_t> gUniqId(0UL);
  for (auto& output : outputs_) {
    for (auto& output_name : output.second) {
      if (output_name == kTempVarName) {
        output_name += type_;
        output_name += "@";
        output_name += std::to_string(gUniqId.fetch_add(1));
      }
    }
  }
}
141

Y
Yu Yang 已提交
142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
  std::vector<std::string> ret_val;
  if (has_intermediate) {
    // push all outputs into ret_val
    for (auto& o : outputs_) {
      ret_val.reserve(ret_val.size() + o.second.size());
      ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
    }
    return ret_val;
  }
  auto it = OpProtos().find(type_);
  PADDLE_ENFORCE(
      it != OpProtos().end(),
      "Operator %s not registered, cannot figure out intermediate outputs",
      type_);

  // get all OpProto::Var for outputs
  for (auto& o : it->second.outputs()) {
    // ignore all intermediate output
    if (o.intermediate()) continue;
    auto out = outputs_.find(o.name());
    if (out != outputs_.end()) {
      ret_val.reserve(ret_val.size() + out->second.size());
      ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
    }
  }
  return ret_val;
}

}  // namespace framework
}  // namespace paddle