grad_op_builder.cc

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
express or implied. See the License for the specific language governing
permissions and limitations under the License. */

#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace framework {
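// Selects which side of an operator's argument map is being walked:
// its input argument list or its output argument list.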
enum class OpArgType { IN, OUT };

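// Copies one side (inputs or outputs) of the forward op's variable-name map
// into `vars` for the gradient op. When `is_grad` is true, each argument name
// and variable name is rewritten with GradVarName(), i.e. given the gradient
// suffix (e.g. "X" -> "X@GRAD").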
static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
                       bool is_grad, VariableNameMap* vars) {
  const auto& src_inout =
      src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs();
  auto& dst_inout = *vars;
  auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto();
  const auto& src_arg_list =
      src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
  for (const auto& arg : src_arg_list) {
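    // An argument declared not_in_gradient in the proto is omitted from the
    // plain (non-grad) copies; it is still carried over on the is_grad pass.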
    if (arg.not_in_gradient() && !is_grad) continue;
    const std::string src_name = arg.name();
    std::string dst_name = is_grad ? GradVarName(src_name) : src_name;
    dst_inout[dst_name].reserve(src_inout.at(src_name).size());
    for (auto& var_name : src_inout.at(src_name)) {
      std::string s = is_grad ? GradVarName(var_name) : var_name;
      dst_inout[dst_name].emplace_back(s);
    }
  }
}

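// Builds the gradient operator for `op`. The gradient op takes the forward
// op's inputs (I), outputs (O), and output gradients (OG) as its inputs, and
// produces the input gradients (IG) as its outputs. As an illustration, a
// `mul` op with inputs {X, Y} and output {Out} would yield a gradient op
// reading X, Y, Out, and Out@GRAD and writing X@GRAD and Y@GRAD (the exact
// argument set depends on the op's proto).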
OperatorBase* BuildGradOp(const OperatorBase* op) {
  auto& info = OpInfoMap::Instance().Get(op->Type());
  PADDLE_ENFORCE(info.HasGradientOp());

  VariableNameMap inputs;
  VariableNameMap outputs;
  TransOpArg(op, OpArgType::IN, false, &inputs);   // I: forward inputs
  TransOpArg(op, OpArgType::OUT, false, &inputs);  // O: forward outputs
  TransOpArg(op, OpArgType::OUT, true, &inputs);   // OG: output gradients
  TransOpArg(op, OpArgType::IN, true, &outputs);   // IG: input gradients

  auto& grad_info = OpInfoMap::Instance().Get(info.grad_op_type_);
  return grad_info.Creator()(info.grad_op_type_, inputs, outputs, op->Attrs());
}

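// Compile-time counterpart of TransOpArg: copies one side of the forward
// op's proto description into the gradient op's OpDescBind, renaming argument
// and variable names with GradVarName() when `is_grad` is true, and writing
// them to the destination side selected by `dst_type`.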
static void TransOpDescArg(const OpDescBind* src_op, const OpArgType& src_type,
                           bool is_grad, OpDescBind* dst_op,
                           const OpArgType& dst_type) {
  PADDLE_ENFORCE(dst_op != nullptr,
                 "Protobuf desc of gradient op must be initialized first.");
  const auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto();
  const auto& src_arg_list =
      src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
  for (const auto& arg : src_arg_list) {
    if (arg.not_in_gradient() && !is_grad) continue;
    const std::string src_name = arg.name();
    std::vector<std::string> vars = src_type == OpArgType::IN
                                        ? src_op->Input(src_name)
                                        : src_op->Output(src_name);
    if (is_grad) {
      for (std::string& var : vars) {
        var = GradVarName(var);
      }
    }
    std::string dst_name = is_grad ? GradVarName(src_name) : src_name;
    dst_type == OpArgType::IN ? dst_op->SetInput(dst_name, vars)
                              : dst_op->SetOutput(dst_name, vars);
  }
}

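// Fills in `grad_op` as the gradient op description of `forw_op`, using the
// same I/O/OG -> inputs, IG -> outputs layout as BuildGradOp above.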
void CompleteGradOpDesc(const OpDescBind* forw_op, OpDescBind* grad_op) {
  auto& info = OpInfoMap::Instance().Get(forw_op->Type());
  PADDLE_ENFORCE(info.HasGradientOp());

  grad_op->SetType(info.grad_op_type_);

  TransOpDescArg(forw_op, OpArgType::IN, false, grad_op, OpArgType::IN);   // I
  TransOpDescArg(forw_op, OpArgType::OUT, false, grad_op, OpArgType::IN);  // O
  TransOpDescArg(forw_op, OpArgType::OUT, true, grad_op, OpArgType::IN);   // OG
  TransOpDescArg(forw_op, OpArgType::IN, true, grad_op, OpArgType::OUT);   // IG

  grad_op->SetAttrMap(forw_op->GetAttrMap());
}

}  // namespace framework
}  // namespace paddle