// custom_operator_node.h
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/hooks.h"
#include "paddle/fluid/eager/tensor_wrapper.h"
#include "paddle/fluid/framework/custom_operator.h"
#include "paddle/utils/any.h"

namespace egr {
class RunCustomOpNode : public GradNodeBase {
 public:
  // Constructor: configure fwd input tensors to grad node
  explicit RunCustomOpNode(size_t bwd_in_slot_num, size_t bwd_out_slot_num,
                           const std::string& op_type)
      : GradNodeBase(bwd_in_slot_num, bwd_out_slot_num), op_type_(op_type) {
    VLOG(6) << "Construct RunCustomOpNode for op: " << op_type;
  }

  ~RunCustomOpNode() override {
    VLOG(6) << "Destruct RunCustomOpNode for op: " << op_type_;
  }

  // Functor: perform backward computations
39 40 41 42 43 44 45
  virtual paddle::small_vector<std::vector<paddle::experimental::Tensor>,
                               kSlotSmallVectorSize>
  operator()(  // NOLINT
      paddle::small_vector<std::vector<paddle::experimental::Tensor>,
                           kSlotSmallVectorSize>& grads,  // NOLINT
      bool create_graph = false,
      bool is_new_grad = false)  // NOLINT
46
      override;
47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64

  std::string name() {
    return paddle::string::Sprintf("RunCustomOpNode: %s_grad", op_type_);
  }

  static std::vector<egr::TensorWrapper> ConstructTensorWrapper(
      const std::vector<paddle::experimental::Tensor>& fwd_var) {
    std::vector<egr::TensorWrapper> res;
    for (auto const& var : fwd_var) {
      res.emplace_back(var);
    }
    return res;
  }

  static std::vector<paddle::experimental::Tensor> Recover(
      std::vector<egr::TensorWrapper>* fwd_var) {
    std::vector<paddle::experimental::Tensor> res;
    for (size_t i = 0; i < fwd_var->size(); i++) {
65
      res.emplace_back(fwd_var->at(i).recover());
66 67 68 69
    }
    return res;
  }

70 71
  void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }

72 73
  void SetAttrs(const std::vector<paddle::any>& attr) { attrs_ = attr; }

74 75 76 77 78 79
  std::shared_ptr<GradNodeBase> Copy() const override {
    auto copied_node =
        std::shared_ptr<RunCustomOpNode>(new RunCustomOpNode(*this));
    return copied_node;
  }

80 81 82 83 84 85 86 87 88 89 90
 public:
  std::unordered_map<int, std::vector<egr::TensorWrapper>> fwd_outs;
  std::unordered_map<int, std::vector<egr::TensorWrapper>> fwd_ins;
  std::unordered_map<int, int> grads2grad_in_map;

 private:
  std::vector<paddle::any> attrs_;
  std::string op_type_{""};
};

}  // namespace egr