net_op.cc 3.7 KB
Newer Older
D
dzhwinter 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
//  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
L
liaogang 已提交
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
/*
  Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
*/

Y
Yan Chunwei 已提交
30
#include "paddle/operators/net_op.h"
Y
Yu Yang 已提交
31
#include <set>
D
dongzhihong 已提交
32
#include "paddle/framework/op_registry.h"
S
Superjom 已提交
33 34

namespace paddle {
Y
Yan Chunwei 已提交
35
namespace operators {
S
Superjom 已提交
36

37 38
// Key under which CompleteAddOp() stores the variable names it collects:
// the net's inputs go to inputs_[kAll] and its outputs to outputs_[kAll].
const char NetOp::kAll[] = "all";

Y
Yu Yang 已提交
39
void NetOp::CompleteAddOp(bool calc) {
Y
Yu Yang 已提交
40 41
  add_op_done_ = true;
  if (!calc) return;
Y
Yu Yang 已提交
42 43
  std::set<std::string> input_set;
  std::set<std::string> output_set;
S
Superjom 已提交
44
  for (auto& op : ops_) {
Q
qiaolongfei 已提交
45
    for (auto& ipt : op->Inputs()) {
Y
Yu Yang 已提交
46
      for (auto& var_name : ipt.second) {
Q
qijun 已提交
47 48 49 50
        // If input variable has been in output set, then it will be
        // added into intermediate_outputs_. Otherwise, it will be
        // added into input set.
        if (Contains(output_set, var_name)) {
51
          intermediate_outputs_.insert(var_name);
Q
qijun 已提交
52 53
        } else {
          input_set.insert(var_name);
Y
Yu Yang 已提交
54
        }
Q
Qiao Longfei 已提交
55 56 57
      }
    }

Q
qiaolongfei 已提交
58
    for (auto& opt : op->Outputs()) {
Y
Yu Yang 已提交
59 60 61
      for (auto& var_name : opt.second) {
        output_set.insert(var_name);
      }
Q
Qiao Longfei 已提交
62
    }
S
Superjom 已提交
63
  }
64
  auto& inputs = inputs_[kAll];
Y
Yu Yang 已提交
65 66
  inputs.reserve(input_set.size());
  std::copy(input_set.begin(), input_set.end(), std::back_inserter(inputs));
67
  auto& outputs = outputs_[kAll];
Y
Yu Yang 已提交
68 69
  outputs.reserve(output_set.size());
  std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs));
D
dongzhihong 已提交
70
}
Q
Qiao Longfei 已提交
71

72
// Builds a multi-line description of this net.
//
// The first line is whatever OperatorBase::DebugStringEx(scope) returns for
// the net itself. It is followed by the DebugStringEx(scope) text of every op
// in ops_, in order, with each of those lines indented by four spaces.
//
// scope: forwarded unchanged to every DebugStringEx call.
// Returns the assembled text; every emitted line ends with '\n'.
std::string NetOp::DebugStringEx(const framework::Scope* scope) const {
  std::ostringstream os;
  // '\n' rather than std::endl: the target is an in-memory stream, so the
  // flush std::endl forces on every line does no useful work.
  os << OperatorBase::DebugStringEx(scope) << '\n';
  for (const auto& op : ops_) {
    // Re-split the child's text line by line so that each line, not just the
    // first one, receives the indentation.
    std::istringstream is(op->DebugStringEx(scope));
    for (std::string line; std::getline(is, line);) {
      os << "    " << line << '\n';
    }
  }
  return os.str();
}

Y
Yu Yang 已提交
84 85
// Always reports true: this operator is a NetOp.
bool NetOp::IsNetOp() const {
  return true;
}

86
std::vector<std::string> NetOp::OutputVars(bool has_intermediate) const {
Y
Yu Yang 已提交
87 88 89 90 91 92
  std::vector<std::string> all;
  for (auto& pair : this->outputs_) {
    for (auto& var_name : pair.second) {
      all.push_back(var_name);
    }
  }
93
  if (has_intermediate) {
Y
Yu Yang 已提交
94
    return all;
95 96 97 98 99 100 101 102 103 104
  }
  std::vector<std::string> ret_val;
  for (auto& each : all) {
    if (!Contains(intermediate_outputs_, each)) {
      ret_val.push_back(each);
    }
  }
  return ret_val;
}

Y
Yu Yang 已提交
105 106
// Constructs a NetOp by forwarding all four arguments, unchanged, to the
// OperatorBase constructor. The constructor body itself does nothing.
NetOp::NetOp(const std::string& type,
             const framework::VariableNameMap& inputs,
             const framework::VariableNameMap& outputs,
             const framework::AttributeMap& attrs)
    : OperatorBase(type, inputs, outputs, attrs) {
}
Y
Yu Yang 已提交
109

Y
Yu Yang 已提交
110
std::unique_ptr<framework::OperatorBase> NetOp::Clone() const {
Y
Yu Yang 已提交
111 112 113
  PADDLE_ENFORCE(
      add_op_done_,
      "Must clone a sealed NetOp, invoke Net::CompleteAddOp before clone");
Y
Yu Yang 已提交
114
  return std::unique_ptr<OperatorBase>(new NetOp(*this));
Y
Yu Yang 已提交
115 116
}

Y
Yan Chunwei 已提交
117
}  // namespace operators
S
Superjom 已提交
118
}  // namespace paddle